Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
DanPuzzuoli committed Dec 18, 2023
1 parent 8908ef4 commit 298b606
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 40 deletions.
6 changes: 4 additions & 2 deletions qiskit_dynamics/solvers/fixed_step_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Custom fixed step solvers.
"""

from typing import Callable, Optional, Union, Tuple, List
from typing import Callable, Optional, Tuple
from warnings import warn
import numpy as np
from scipy.integrate._ivp.ivp import OdeResult
Expand Down Expand Up @@ -611,7 +611,9 @@ def fixed_step_lmde_solver_parallel_template_jax(
return trim_t_results(results, t_eval)


def get_fixed_step_sizes(t_span: ArrayLike, t_eval: ArrayLike, max_dt: float) -> Tuple[ArrayLike, ArrayLike, ArrayLike]:
def get_fixed_step_sizes(
t_span: ArrayLike, t_eval: ArrayLike, max_dt: float
) -> Tuple[ArrayLike, ArrayLike, ArrayLike]:
"""Merge ``t_span`` and ``t_eval``, and determine the number of time steps and
and step sizes (no larger than ``max_dt``) required to fixed-step integrate between
each time point.
Expand Down
2 changes: 0 additions & 2 deletions qiskit_dynamics/solvers/jax_odeint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"""

from typing import Callable, Optional
import numpy as np
from scipy.integrate._ivp.ivp import OdeResult

from qiskit_dynamics import DYNAMICS_NUMPY as unp
Expand All @@ -28,7 +27,6 @@
from .solver_utils import merge_t_args_jax, trim_t_results_jax

try:
import jax.numpy as jnp
from jax.experimental.ode import odeint
except ImportError:
pass
Expand Down
18 changes: 6 additions & 12 deletions qiskit_dynamics/solvers/solver_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
Solver functions.
"""

from typing import Optional, Union, Callable, Tuple, List, TypeVar
from typing import Optional, Union, Callable, Tuple, TypeVar
from warnings import warn

from scipy.integrate import OdeSolver

from scipy.integrate._ivp.ivp import OdeResult # pylint: disable=unused-import
from scipy.integrate._ivp.ivp import OdeResult

from qiskit import QiskitError
from qiskit_dynamics import DYNAMICS_NUMPY as unp
Expand Down Expand Up @@ -132,7 +132,7 @@ def solve_ode(
method: Optional[Union[str, OdeSolver, DiffraxAbstractSolver]] = "DOP853",
t_eval: Optional[ArrayLike] = None,
**kwargs,
):
) -> OdeResult:
r"""General interface for solving Ordinary Differential Equations (ODEs).
ODEs are differential equations of the form
Expand Down Expand Up @@ -223,7 +223,7 @@ def solve_lmde(
method: Optional[Union[str, OdeSolver, DiffraxAbstractSolver]] = "DOP853",
t_eval: Optional[ArrayLike] = None,
**kwargs,
):
) -> OdeResult:
r"""General interface for solving Linear Matrix Differential Equations (LMDEs)
in standard form.
Expand Down Expand Up @@ -393,10 +393,7 @@ def setup_generator_model_rhs_y0_in_frame_basis(

# if model not specified in frame basis, transform initial state into frame basis
if not model_in_frame_basis:
if (
isinstance(generator_model, LindbladModel)
and generator_model.vectorized
):
if isinstance(generator_model, LindbladModel) and generator_model.vectorized:
if generator_model.rotating_frame.frame_basis is not None:
y0 = generator_model.rotating_frame.vectorized_frame_basis_adjoint @ y0
elif isinstance(generator_model, LindbladModel):
Expand Down Expand Up @@ -438,10 +435,7 @@ def results_y_out_of_frame_basis(
if y0_ndim == 1:
results_y = results_y.T

if (
isinstance(generator_model, LindbladModel)
and generator_model.vectorized
):
if isinstance(generator_model, LindbladModel) and generator_model.vectorized:
if generator_model.rotating_frame.frame_basis is not None:
results_y = generator_model.rotating_frame.vectorized_frame_basis @ results_y
elif isinstance(generator_model, LindbladModel):
Expand Down
10 changes: 3 additions & 7 deletions qiskit_dynamics/solvers/solver_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Utility functions for solvers.
"""

from typing import Optional, Union, List, Tuple, Callable
from typing import Optional, List, Tuple, Callable
import numpy as np
from scipy.integrate._ivp.ivp import OdeResult

Expand All @@ -43,9 +43,7 @@ def is_lindblad_model_not_vectorized(obj: any) -> bool:
return isinstance(obj, LindbladModel) and not obj.vectorized


def merge_t_args(
t_span: ArrayLike, t_eval: Optional[ArrayLike] = None
) -> np.ndarray:
def merge_t_args(t_span: ArrayLike, t_eval: Optional[ArrayLike] = None) -> np.ndarray:
"""Merge ``t_span`` and ``t_eval`` into a single array.
Validition is similar to scipy ``solve_ivp``: ``t_eval`` must be contained in ``t_span``, and be
Expand Down Expand Up @@ -121,9 +119,7 @@ def trim_t_results(
return results


def merge_t_args_jax(
t_span: ArrayLike, t_eval: Optional[ArrayLike] = None
) -> jnp.ndarray:
def merge_t_args_jax(t_span: ArrayLike, t_eval: Optional[ArrayLike] = None) -> jnp.ndarray:
"""JAX-compilable version of merge_t_args.
Rather than raise errors, sets return values to ``jnp.nan`` to signal errors.
Expand Down
2 changes: 1 addition & 1 deletion test/dynamics/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def setUpClass(cls):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
import diffrax # pylint: disable=import-outside-toplevel,unused-import

# pylint: disable=import-outside-toplevel
import jax

Expand Down
6 changes: 2 additions & 4 deletions test/dynamics/solvers/test_fixed_step_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@ def constant_rhs(t, y=None):
self.constant_rhs = constant_rhs

Y = np.array([[0.0, -1j], [1j, 0.0]])
self.linear_generator = (
lambda t: -1j * (X + t * Y)
)
self.linear_generator = lambda t: -1j * (X + t * Y)

def linear_rhs(t, y=None):
if y is None:
Expand Down Expand Up @@ -99,7 +97,7 @@ def linear_rhs(t, y=None):
)

def random_generator(t):
return unp.sin(t) * rand_ops[0] + (t**5) * rand_ops[1] + unp.exp(t) * rand_ops[2]
return unp.sin(t) * rand_ops[0] + (t**5) * rand_ops[1] + unp.exp(t) * rand_ops[2]

self.random_generator = random_generator

Expand Down
4 changes: 3 additions & 1 deletion test/dynamics/solvers/test_jax_odeint.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def test_transformations_w_t_span_t_eval_no_overlap(self):
def func(t_s, t_e):
results = jax_odeint(self.simple_rhs, t_s, y0, t_eval=t_e, atol=1e-10, rtol=1e-10)
return results.t, results.y

t, y = jit(func)(t_span, t_eval)

self.assertAllClose(t_eval, t)
Expand Down Expand Up @@ -165,7 +166,8 @@ def sim_function(a):
return results.y[-1].real.sum()

self.assertAllClose(
jit(grad(lambda a: sim_function(a).real.sum()))(2.0), 4 * (0.5 + (2.0**3 - 1.0**3) / 3)
jit(grad(lambda a: sim_function(a).real.sum()))(2.0),
4 * (0.5 + (2.0**3 - 1.0**3) / 3),
)

def test_empty_integration(self):
Expand Down
21 changes: 15 additions & 6 deletions test/dynamics/solvers/test_solver_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# pylint: disable=invalid-name
# pylint: disable=invalid-name,no-member

"""Standardized test cases for results of calls to solve_lmde and solve_ode,
for both variable and fixed-step methods. These tests set up common test cases
Expand All @@ -31,7 +31,7 @@
from qiskit_dynamics.signals import Signal, DiscreteSignal
from qiskit_dynamics import solve_ode, solve_lmde

from ..common import QiskitDynamicsTestCase, DiffraxTestBase, JAXTestBase, test_array_backends
from ..common import test_array_backends, DiffraxTestBase

try:
from diffrax import PIDController, Tsit5, Dopri5
Expand Down Expand Up @@ -69,7 +69,9 @@ def basic_rhs(t, y=None):
self.r = 0.1
signals = [self.w, Signal(lambda t: 1.0, self.w)]
operators = [-1j * 2 * np.pi * self.Z / 2, -1j * 2 * np.pi * self.r * self.X / 2]
self.simple_model = GeneratorModel(operators=operators, signals=signals, array_library=self.array_library())
self.simple_model = GeneratorModel(
operators=operators, signals=signals, array_library=self.array_library()
)

# construct randomized RHS
dim = 7
Expand Down Expand Up @@ -99,7 +101,7 @@ def basic_rhs(t, y=None):
signals=[self.pseudo_random_signal],
static_operator=static_operator,
rotating_frame=frame_op,
array_library=self.array_library()
array_library=self.array_library(),
)

# simulate directly out of frame
Expand Down Expand Up @@ -227,12 +229,15 @@ def test_pseudo_random_jit_grad(self):

def func(a):
self.pseudo_random_model.signals = [Signal(a, carrier_freq=1.0)]
results = self.solve(self.pseudo_random_model, t_span=[0.0, 0.1], y0=self.pseudo_random_y0)
results = self.solve(
self.pseudo_random_model, t_span=[0.0, 0.1], y0=self.pseudo_random_y0
)
self.pseudo_random_model.signals = None
return results.y[-1]

# verify we can jit
from jax import jit

self.assertAllClose(jit(func)(1.0), func(1.0))

# just verify that this runs without error
Expand All @@ -258,6 +263,7 @@ def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_ode, **kwargs):
def is_ode_method(self):
return True


@partial(test_array_backends, array_libraries=["jax"])
class Testjax_RK4(TestSolverMethodJAX):
"""Test class for jax_RK4_solver."""
Expand Down Expand Up @@ -314,6 +320,7 @@ def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs):
**kwargs,
)


@partial(test_array_backends, array_libraries=["numpy"])
class Testscipy_expm_magnus2(TestSolverMethod):
"""Test class for scipy_expm_solver with magnus_order==2."""
Expand All @@ -330,6 +337,7 @@ def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs):
**kwargs,
)


@partial(test_array_backends, array_libraries=["numpy"])
class Testscipy_expm_magnus3(TestSolverMethod):
"""Test class for scipy_expm_solver with magnus_order==3."""
Expand Down Expand Up @@ -416,6 +424,7 @@ def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_lmde, **kwargs):
**kwargs,
)


test_array_backends(Testlanczos_diag, array_libraries=["scipy_sparse"])


Expand Down Expand Up @@ -676,4 +685,4 @@ def solve(self, rhs, t_span, y0, t_eval=None, solver_func=solve_ode, **kwargs):

@property
def is_ode_method(self):
return True
return True
12 changes: 7 additions & 5 deletions test/dynamics/solvers/test_solver_functions_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def setUp(self):
hamiltonian_signals=[Signal(1.0, 5.0)],
static_hamiltonian=Operator.from_label("Z"),
static_dissipators=[Operator.from_label("Y")],
vectorized=True
vectorized=True,
)

self.frame_op = 1.2 * Operator.from_label("X") - 3.132 * Operator.from_label("Y")
Expand All @@ -197,15 +197,15 @@ def setUp(self):
operators=[Operator.from_label("X")],
signals=[Signal(1.0, 5.0)],
static_operator=Operator.from_label("Z"),
rotating_frame=self.frame_op
rotating_frame=self.frame_op,
)

self.rf_lindblad_model = LindbladModel(
hamiltonian_operators=[Operator.from_label("X")],
hamiltonian_signals=[Signal(1.0, 5.0)],
static_hamiltonian=Operator.from_label("Z"),
static_dissipators=[Operator.from_label("Y")],
rotating_frame=self.frame_op
rotating_frame=self.frame_op,
)

self.rf_vec_lindblad_model = LindbladModel(
Expand All @@ -214,7 +214,7 @@ def setUp(self):
static_hamiltonian=Operator.from_label("Z"),
static_dissipators=[Operator.from_label("Y")],
rotating_frame=self.frame_op,
vectorized=True
vectorized=True,
)

def test_hamiltonian_setup_no_frame(self):
Expand Down Expand Up @@ -300,7 +300,9 @@ def test_vectorized_lindblad_setup(self):
"""Test functions for vectorized LindbladModel with frame."""

y0 = np.array([[3.43, 1.31], [3.0, 1.23]]).flatten()
gen, rhs, new_y0, _ = setup_generator_model_rhs_y0_in_frame_basis(self.rf_vec_lindblad_model, y0)
gen, rhs, new_y0, _ = setup_generator_model_rhs_y0_in_frame_basis(
self.rf_vec_lindblad_model, y0
)

# expect nothing to happen
self.assertAllClose(np.kron(self.Uadj.conj(), self.Uadj) @ y0, new_y0)
Expand Down

0 comments on commit 298b606

Please sign in to comment.