Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

changed np.nparray to np.typing.NDArray #4526

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
7 changes: 4 additions & 3 deletions benchmarks/different_model_options.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import numpy.typing as npt
import pybamm
from benchmarks.benchmark_utils import set_random_seed
import numpy as np


def compute_discretisation(model, param):
Expand Down Expand Up @@ -33,8 +34,8 @@ def build_model(parameter, model_, option, value):
class SolveModel:
solver: pybamm.BaseSolver
model: pybamm.BaseModel
t_eval: np.ndarray
t_interp: np.ndarray | None
t_eval: npt.NDArray
t_interp: npt.NDArray | None

def solve_setup(self, parameter, model_, option, value, solver_class):
import importlib
Expand Down
11 changes: 6 additions & 5 deletions benchmarks/time_solve_models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy as np
import numpy.typing as npt
# Write the benchmarking functions here.
# See "Writing benchmarks" in the asv docs for more information.

import pybamm
from benchmarks.benchmark_utils import set_random_seed
import numpy as np


def solve_model_once(model, solver, t_eval, t_interp):
Expand All @@ -30,8 +31,8 @@ class TimeSolveSPM:
)
model: pybamm.BaseModel
solver: pybamm.BaseSolver
t_eval: np.ndarray
t_interp: np.ndarray | None
t_eval: npt.NDArray
t_interp: npt.NDArray | None

def setup(self, solve_first, parameters, solver_class):
set_random_seed()
Expand Down Expand Up @@ -96,7 +97,7 @@ class TimeSolveSPMe:
)
model: pybamm.BaseModel
solver: pybamm.BaseSolver
t_eval: np.ndarray
t_eval: npt.NDArray

def setup(self, solve_first, parameters, solver_class):
set_random_seed()
Expand Down Expand Up @@ -160,7 +161,7 @@ class TimeSolveDFN:
)
model: pybamm.BaseModel
solver: pybamm.BaseSolver
t_eval: np.ndarray
t_eval: npt.NDArray

def setup(self, solve_first, parameters, solver_class):
set_random_seed()
Expand Down
6 changes: 4 additions & 2 deletions src/pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
import numpy.typing as npt

#
# Interface for discretisation
#
import pybamm
import numpy as np
from collections import defaultdict, OrderedDict
from scipy.sparse import block_diag, csc_matrix, csr_matrix
from scipy.sparse.linalg import inv
Expand Down Expand Up @@ -1012,7 +1014,7 @@ def check_initial_conditions(self, model):
# Individual
for var, eqn in model.initial_conditions.items():
ic_eval = eqn.evaluate(t=0, inputs="shape test")
if not isinstance(ic_eval, np.ndarray):
if not isinstance(ic_eval, npt.NDArray):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

raise pybamm.ModelError(
"initial conditions must be numpy array after discretisation but "
f"they are {type(ic_eval)} for variable '{var}'."
Expand Down
8 changes: 5 additions & 3 deletions src/pybamm/experiment/step/base_step.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
import numpy.typing as npt

#
# Private classes and functions for experiment steps
#
import pybamm
import numpy as np
from datetime import datetime
from .step_termination import _read_termination
import numbers
Expand Down Expand Up @@ -74,7 +76,7 @@ def __init__(
self.input_duration = duration
self.input_value = value
# Check if drive cycle
is_drive_cycle = isinstance(value, np.ndarray)
is_drive_cycle = isinstance(value, npt.NDArray)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be changed as we are checking if value is an instance of np.ndarray. We could have changed it if we were checking the type of the value, but we don't want that.

is_python_function = callable(value)
if is_drive_cycle:
if value.ndim != 2 or value.shape[1] != 2:
Expand Down Expand Up @@ -260,7 +262,7 @@ def default_duration(self, value):
Default duration for the step is one day (24 hours) or the duration of the
drive cycle
"""
if isinstance(value, np.ndarray):
if isinstance(value, npt.NDArray):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

t = value[:, 0]
return t[-1]
else:
Expand Down
12 changes: 7 additions & 5 deletions src/pybamm/expression_tree/array.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
import numpy.typing as npt

Comment on lines +1 to +3
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

numpy imports should not be at the top of the file. The order should be -

  1. module-level comment/docstring
  2. from __future__ import annotations
  3. All other imports

This will fix the failing style check (which you can also run locally using instructions provided above).

#
# NumpyArray class
#
from __future__ import annotations
import numpy as np
from scipy.sparse import csr_matrix, issparse

import pybamm
Expand Down Expand Up @@ -38,7 +40,7 @@ class Array(pybamm.Symbol):

def __init__(
self,
entries: np.ndarray | list[float] | csr_matrix,
entries: npt.NDArray | list[float] | csr_matrix,
name: str | None = None,
domain: DomainType = None,
auxiliary_domains: AuxiliaryDomainType = None,
Expand Down Expand Up @@ -144,8 +146,8 @@ def create_copy(
def _base_evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol._base_evaluate()`."""
Expand All @@ -165,7 +167,7 @@ def to_json(self):
Method to serialise an Array object into JSON.
"""

if isinstance(self.entries, np.ndarray):
if isinstance(self.entries, npt.NDArray):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

matrix = self.entries.tolist()
elif isinstance(self.entries, csr_matrix):
matrix = {
Expand Down
14 changes: 8 additions & 6 deletions src/pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import numpy as np
import numpy.typing as npt

#
# Binary operator classes
#
from __future__ import annotations
import numbers

import numpy as np
import sympy
from scipy.sparse import csr_matrix, issparse
import functools
Expand All @@ -22,13 +24,13 @@ def _preprocess_binary(
) -> tuple[pybamm.Symbol, pybamm.Symbol]:
if isinstance(left, (float, int, np.number)):
left = pybamm.Scalar(left)
elif isinstance(left, np.ndarray):
elif isinstance(left, npt.NDArray):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For all isinstance calls basically

if left.ndim > 1:
raise ValueError("left must be a 1D array")
left = pybamm.Vector(left)
if isinstance(right, (float, int, np.number)):
right = pybamm.Scalar(right)
elif isinstance(right, np.ndarray):
elif isinstance(right, npt.NDArray):
if right.ndim > 1:
raise ValueError("right must be a 1D array")
right = pybamm.Vector(right)
Expand Down Expand Up @@ -152,8 +154,8 @@ def _binary_new_copy(self, left: ChildSymbol, right: ChildSymbol):
def evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol.evaluate()`."""
Expand Down Expand Up @@ -558,7 +560,7 @@ def _binary_jac(self, left_jac, right_jac):
def _binary_evaluate(self, left, right):
"""See :meth:`pybamm.BinaryOperator._binary_evaluate()`."""
# numpy 1.25 deprecation warning: extract value from numpy arrays
if isinstance(right, np.ndarray):
if isinstance(right, npt.NDArray):
return int(left == right.item())
else:
return int(left == right)
Expand Down
12 changes: 7 additions & 5 deletions src/pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import numpy as np
import numpy.typing as npt

#
# Concatenation classes
#
from __future__ import annotations
import copy
from collections import defaultdict

import numpy as np
import sympy
from scipy.sparse import issparse, vstack
from collections.abc import Sequence
Expand Down Expand Up @@ -111,7 +113,7 @@ def get_children_domains(self, children: Sequence[pybamm.Symbol]):

return domains

def _concatenation_evaluate(self, children_eval: list[np.ndarray]):
def _concatenation_evaluate(self, children_eval: list[npt.NDArray]):
"""See :meth:`Concatenation._concatenation_evaluate()`."""
if len(children_eval) == 0:
return np.array([])
Expand All @@ -121,8 +123,8 @@ def _concatenation_evaluate(self, children_eval: list[np.ndarray]):
def evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol.evaluate()`."""
Expand Down Expand Up @@ -366,7 +368,7 @@ def create_slices(self, node: pybamm.Symbol) -> defaultdict:
start = end
return slices

def _concatenation_evaluate(self, children_eval: list[np.ndarray]):
def _concatenation_evaluate(self, children_eval: list[npt.NDArray]):
"""See :meth:`Concatenation._concatenation_evaluate()`."""
# preallocate vector
vector = np.empty((self._size, 1))
Expand Down
4 changes: 2 additions & 2 deletions src/pybamm/expression_tree/discrete_time_sum.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy.typing as npt
import pybamm
import numpy as np


class DiscreteTimeData(pybamm.Interpolant):
Expand All @@ -19,7 +19,7 @@ class DiscreteTimeData(pybamm.Interpolant):

"""

def __init__(self, time_points: np.ndarray, data: np.ndarray, name: str):
def __init__(self, time_points: npt.NDArray, data: npt.NDArray, name: str):
super().__init__(time_points, data, pybamm.t, name)

def create_copy(self, new_children=None, perform_simplifications=True):
Expand Down
8 changes: 5 additions & 3 deletions src/pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import numpy as np
import numpy.typing as npt

#
# Function classes and methods
#
from __future__ import annotations

import numpy as np
from scipy import special
import sympy
from typing import Callable
Expand Down Expand Up @@ -122,8 +124,8 @@ def _function_jac(self, children_jacs):
def evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol.evaluate()`."""
Expand Down
7 changes: 4 additions & 3 deletions src/pybamm/expression_tree/independent_variable.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy.typing as npt

#
# IndependentVariable class
#
from __future__ import annotations
import sympy
import numpy as np

import pybamm
from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType
Expand Down Expand Up @@ -94,8 +95,8 @@ def create_copy(
def _base_evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol._base_evaluate()`."""
Expand Down
8 changes: 5 additions & 3 deletions src/pybamm/expression_tree/input_parameter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import numpy as np
import numpy.typing as npt

#
# Parameter classes
#
from __future__ import annotations
import numbers
import numpy as np
import scipy.sparse
import pybamm

Expand Down Expand Up @@ -88,8 +90,8 @@ def _jac(self, variable: pybamm.StateVector) -> pybamm.Matrix:
def _base_evaluate(
self,
t: float | None = None,
y: np.ndarray | None = None,
y_dot: np.ndarray | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
inputs: dict | str | None = None,
):
# inputs should be a dictionary
Expand Down
12 changes: 7 additions & 5 deletions src/pybamm/expression_tree/interpolant.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
import numpy.typing as npt

#
# Interpolating class
#
from __future__ import annotations
import numpy as np
from scipy import interpolate
from collections.abc import Sequence
import numbers
Expand Down Expand Up @@ -43,8 +45,8 @@ class Interpolant(pybamm.Function):

def __init__(
self,
x: np.ndarray | Sequence[np.ndarray],
y: np.ndarray,
x: npt.NDArray | Sequence[npt.NDArray],
y: npt.NDArray,
children: Sequence[pybamm.Symbol] | pybamm.Time,
name: str | None = None,
interpolator: str | None = "linear",
Expand Down Expand Up @@ -96,7 +98,7 @@ def __init__(
x1 = x[0]
else:
x1 = x
x: list[np.ndarray] = [x] # type: ignore[no-redef]
x: list[npt.NDArray] = [x] # type: ignore[no-redef]
x2 = None
if x1.shape[0] != y.shape[0]:
raise ValueError(
Expand Down Expand Up @@ -269,7 +271,7 @@ def create_copy(self, new_children=None, perform_simplifications=True):
def _function_evaluate(self, evaluated_children):
children_eval_flat = []
for child in evaluated_children:
if isinstance(child, np.ndarray):
if isinstance(child, npt.NDArray):
children_eval_flat.append(child.flatten())
else:
children_eval_flat.append(child)
Expand Down
6 changes: 4 additions & 2 deletions src/pybamm/expression_tree/matrix.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
import numpy.typing as npt

#
# Matrix class
#
from __future__ import annotations
import numpy as np
from scipy.sparse import csr_matrix, issparse

import pybamm
Expand All @@ -16,7 +18,7 @@ class Matrix(pybamm.Array):

def __init__(
self,
entries: np.ndarray | list[float] | csr_matrix,
entries: npt.NDArray | list[float] | csr_matrix,
name: str | None = None,
domain: DomainType = None,
auxiliary_domains: AuxiliaryDomainType = None,
Expand Down
Loading
Loading