Skip to content

Commit

Permalink
Rename have_optional_dependency (pybamm-team#3866)
Browse files Browse the repository at this point in the history
* Rename have_optional_dependency

* Change log

* Fix import

* style: pre-commit fixes

* Update pybamm/util.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Arjun Verma <[email protected]>
  • Loading branch information
3 people authored and lorenzofavaro committed Mar 13, 2024
1 parent 4a5c6eb commit 885fe90
Show file tree
Hide file tree
Showing 18 changed files with 60 additions and 62 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

## Breaking changes

- Renamed "have_optional_dependency" to "import_optional_dependency" ([#3866](https://github.com/pybamm-team/PyBaMM/pull/3866))
- Integrated the `[latexify]` extra into the core PyBaMM package, deprecating the `pybamm[latexify]` set of optional dependencies. SymPy is now a required dependency and will be installed upon installing PyBaMM ([#3848](https://github.com/pybamm-team/PyBaMM/pull/3848))
- Renamed "testing" argument for plots to "show_plot" and flipped its meaning (show_plot=True is now the default and shows the plot) ([#3842](https://github.com/pybamm-team/PyBaMM/pull/3842))
- Dropped support for BPX version 0.3.0 and below ([#3414](https://github.com/pybamm-team/PyBaMM/pull/3414))
Expand Down
8 changes: 4 additions & 4 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,21 +104,21 @@ Only 'core pybamm' is installed by default. The others have to be specified expl

PyBaMM utilizes optional dependencies to allow users to choose which additional libraries they want to use. Managing these optional dependencies and their imports is essential to provide flexibility to PyBaMM users.

PyBaMM provides a utility function `have_optional_dependency`, to check for the availability of optional dependencies within methods. This function can be used to conditionally import optional dependencies only if they are available. Here's how to use it:
PyBaMM provides a utility function `import_optional_dependency`, to check for the availability of optional dependencies within methods. This function can be used to conditionally import optional dependencies only if they are available. Here's how to use it:

Optional dependencies should never be imported at the module level, but always inside methods. For example:

```
def use_pybtex(x,y,z):
pybtex = have_optional_dependency("pybtex")
pybtex = import_optional_dependency("pybtex")
...
```

While importing a specific module instead of an entire package/library:

```python
def use_parse_file(x, y, z):
parse_file = have_optional_dependency("pybtex.database", "parse_file")
parse_file = import_optional_dependency("pybtex.database", "parse_file")
...
```

Expand All @@ -143,7 +143,7 @@ class TestUtil(TestCase):
pybamm.function_using_pybtex(x, y, z)

# Test that the function works when pybtex is available
sys.modules["pybtex"] = pybamm.util.have_optional_dependency("pybtex")
sys.modules["pybtex"] = pybamm.util.import_optional_dependency("pybtex")
pybamm.function_using_pybtex(x, y, z)
```

Expand Down
2 changes: 1 addition & 1 deletion pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
get_parameters_filepath,
have_jax,
install_jax,
have_optional_dependency,
import_optional_dependency,
is_jax_compatible,
get_git_commit_info,
)
Expand Down
12 changes: 6 additions & 6 deletions pybamm/citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import warnings
from sys import _getframe
from pybamm.util import have_optional_dependency
from pybamm.util import import_optional_dependency


class Citations:
Expand Down Expand Up @@ -74,7 +74,7 @@ def read_citations(self):
"""Reads the citations in `pybamm.CITATIONS.bib`. Other works can be cited
by passing a BibTeX citation to :meth:`register`.
"""
parse_file = have_optional_dependency("pybtex.database", "parse_file")
parse_file = import_optional_dependency("pybtex.database", "parse_file")
citations_file = os.path.join(pybamm.root_dir(), "pybamm", "CITATIONS.bib")
bib_data = parse_file(citations_file, bib_format="bibtex")
for key, entry in bib_data.entries.items():
Expand All @@ -85,7 +85,7 @@ def _add_citation(self, key, entry):
previous entry is overwritten
"""

Entry = have_optional_dependency("pybtex.database", "Entry")
Entry = import_optional_dependency("pybtex.database", "Entry")
# Check input types are correct
if not isinstance(key, str) or not isinstance(entry, Entry):
raise TypeError()
Expand Down Expand Up @@ -151,8 +151,8 @@ def _parse_citation(self, key):
key: str
A BibTeX formatted citation
"""
PybtexError = have_optional_dependency("pybtex.scanner", "PybtexError")
parse_string = have_optional_dependency("pybtex.database", "parse_string")
PybtexError = import_optional_dependency("pybtex.scanner", "PybtexError")
parse_string = import_optional_dependency("pybtex.database", "parse_string")
try:
# Parse string as a bibtex citation, and check that a citation was found
bib_data = parse_string(key, bib_format="bibtex")
Expand Down Expand Up @@ -219,7 +219,7 @@ def print(self, filename=None, output_format="text", verbose=False):
"""
# Parse citations that were not known keys at registration, but do not
# fail if they cannot be parsed
pybtex = have_optional_dependency("pybtex")
pybtex = import_optional_dependency("pybtex")
try:
for key in self._unknown_citations:
self._parse_citation(key)
Expand Down
4 changes: 2 additions & 2 deletions pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing_extensions import TypeVar

import pybamm
from pybamm.util import have_optional_dependency
from pybamm.util import import_optional_dependency


class Function(pybamm.Symbol):
Expand Down Expand Up @@ -98,7 +98,7 @@ def _function_diff(self, children: Sequence[pybamm.Symbol], idx: float):
Derivative with respect to child number 'idx'.
See :meth:`pybamm.Symbol._diff()`.
"""
autograd = have_optional_dependency("autograd")
autograd = import_optional_dependency("autograd")
# Store differentiated function, needed in case we want to convert to CasADi
if self.derivative == "autograd":
return Function(
Expand Down
10 changes: 5 additions & 5 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import TYPE_CHECKING, Sequence, cast

import pybamm
from pybamm.util import have_optional_dependency
from pybamm.util import import_optional_dependency
from pybamm.expression_tree.printing.print_name import prettify_print_name

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -479,7 +479,7 @@ def render(self): # pragma: no cover
"""
Print out a visual representation of the tree (this node and its children)
"""
anytree = have_optional_dependency("anytree")
anytree = import_optional_dependency("anytree")
for pre, _, node in anytree.RenderTree(self):
if isinstance(node, pybamm.Scalar) and node.name != str(node.value):
print(f"{pre}{node.name} = {node.value}")
Expand All @@ -498,7 +498,7 @@ def visualise(self, filename: str):
filename to output, must end in ".png"
"""

DotExporter = have_optional_dependency("anytree.exporter", "DotExporter")
DotExporter = import_optional_dependency("anytree.exporter", "DotExporter")
# check that filename ends in .png.
if filename[-4:] != ".png":
raise ValueError("filename should end in .png")
Expand All @@ -518,7 +518,7 @@ def relabel_tree(self, symbol: Symbol, counter: int):
Finds all children of a symbol and assigns them a new id so that they can be
visualised properly using the graphviz output
"""
anytree = have_optional_dependency("anytree")
anytree = import_optional_dependency("anytree")
name = symbol.name
if name == "div":
name = "&nabla;&sdot;"
Expand Down Expand Up @@ -561,7 +561,7 @@ def pre_order(self):
a
b
"""
anytree = have_optional_dependency("anytree")
anytree = import_optional_dependency("anytree")
return anytree.PreOrderIter(self)

def __str__(self):
Expand Down
8 changes: 5 additions & 3 deletions pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from scipy.sparse import csr_matrix, issparse
import sympy
import pybamm
from pybamm.util import have_optional_dependency
from pybamm.util import import_optional_dependency
from pybamm.type_definitions import DomainsType


Expand Down Expand Up @@ -450,7 +450,9 @@ def _unary_new_copy(self, child):

def _sympy_operator(self, child):
"""Override :meth:`pybamm.UnaryOperator._sympy_operator`"""
sympy_Gradient = have_optional_dependency("sympy.vector.operators", "Gradient")
sympy_Gradient = import_optional_dependency(
"sympy.vector.operators", "Gradient"
)
return sympy_Gradient(child)


Expand Down Expand Up @@ -484,7 +486,7 @@ def _unary_new_copy(self, child):

def _sympy_operator(self, child):
"""Override :meth:`pybamm.UnaryOperator._sympy_operator`"""
sympy_Divergence = have_optional_dependency(
sympy_Divergence = import_optional_dependency(
"sympy.vector.operators", "Divergence"
)
return sympy_Divergence(child)
Expand Down
4 changes: 2 additions & 2 deletions pybamm/meshes/scikit_fem_submeshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .meshes import SubMesh
import numpy as np

from pybamm.util import have_optional_dependency
from pybamm.util import import_optional_dependency


class ScikitSubMesh2D(SubMesh):
Expand All @@ -27,7 +27,7 @@ class ScikitSubMesh2D(SubMesh):
"""

def __init__(self, edges, coord_sys, tabs):
skfem = have_optional_dependency("skfem")
skfem = import_optional_dependency("skfem")
self.edges = edges
self.nodes = dict.fromkeys(["y", "z"])
for var in self.nodes.keys():
Expand Down
4 changes: 1 addition & 3 deletions pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import pybamm
from pybamm.expression_tree.operations.serialise import Serialise
import sympy


class BaseModel:
Expand Down Expand Up @@ -1185,8 +1184,7 @@ def latexify(self, filename=None, newline=True, output_variables=None):
This will return first five model equations
>>> model.latexify(newline=False)[1:5]
"""
if sympy:
from pybamm.expression_tree.operations.latexify import Latexify
from pybamm.expression_tree.operations.latexify import Latexify

return Latexify(self, filename, newline).latexify(
output_variables=output_variables
Expand Down
4 changes: 2 additions & 2 deletions pybamm/plotting/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
import pybamm
from .quick_plot import ax_min, ax_max
from pybamm.util import have_optional_dependency
from pybamm.util import import_optional_dependency


def plot(x, y, ax=None, show_plot=True, **kwargs):
Expand All @@ -27,7 +27,7 @@ def plot(x, y, ax=None, show_plot=True, **kwargs):
Keyword arguments, passed to plt.plot
"""
plt = have_optional_dependency("matplotlib.pyplot")
plt = import_optional_dependency("matplotlib.pyplot")

if not isinstance(x, pybamm.Array):
raise TypeError("x must be 'pybamm.Array'")
Expand Down
4 changes: 2 additions & 2 deletions pybamm/plotting/plot2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
import pybamm
from .quick_plot import ax_min, ax_max
from pybamm.util import have_optional_dependency
from pybamm.util import import_optional_dependency


def plot2D(x, y, z, ax=None, show_plot=True, **kwargs):
Expand All @@ -27,7 +27,7 @@ def plot2D(x, y, z, ax=None, show_plot=True, **kwargs):
only display the plot after plt.show() has been called.
"""
plt = have_optional_dependency("matplotlib.pyplot")
plt = import_optional_dependency("matplotlib.pyplot")

if not isinstance(x, pybamm.Array):
raise TypeError("x must be 'pybamm.Array'")
Expand Down
4 changes: 2 additions & 2 deletions pybamm/plotting/plot_summary_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
import numpy as np
import pybamm
from pybamm.util import have_optional_dependency
from pybamm.util import import_optional_dependency


def plot_summary_variables(
Expand All @@ -27,7 +27,7 @@ def plot_summary_variables(
Keyword arguments, passed to plt.subplots.
"""
plt = have_optional_dependency("matplotlib.pyplot")
plt = import_optional_dependency("matplotlib.pyplot")

if isinstance(solutions, pybamm.Solution):
solutions = [solutions]
Expand Down
4 changes: 2 additions & 2 deletions pybamm/plotting/plot_voltage_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
import numpy as np

from pybamm.util import have_optional_dependency
from pybamm.util import import_optional_dependency
from pybamm.simulation import Simulation
from pybamm.solvers.solution import Solution

Expand Down Expand Up @@ -42,7 +42,7 @@ def plot_voltage_components(
solution = input_data.solution
elif isinstance(input_data, Solution):
solution = input_data
plt = have_optional_dependency("matplotlib.pyplot")
plt = import_optional_dependency("matplotlib.pyplot")

# Set a default value for alpha, the opacity
kwargs_fill = {"alpha": 0.6, **kwargs_fill}
Expand Down
20 changes: 10 additions & 10 deletions pybamm/plotting/quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pybamm
from collections import defaultdict
from pybamm.util import have_optional_dependency
from pybamm.util import import_optional_dependency


class LoopList(list):
Expand Down Expand Up @@ -46,7 +46,7 @@ def split_long_string(title, max_words=None):

def close_plots():
"""Close all open figures"""
plt = have_optional_dependency("matplotlib.pyplot")
plt = import_optional_dependency("matplotlib.pyplot")

plt.close("all")

Expand Down Expand Up @@ -473,10 +473,10 @@ def plot(self, t, dynamic=False):
Dimensional time (in 'time_units') at which to plot.
"""

plt = have_optional_dependency("matplotlib.pyplot")
gridspec = have_optional_dependency("matplotlib.gridspec")
cm = have_optional_dependency("matplotlib", "cm")
colors = have_optional_dependency("matplotlib", "colors")
plt = import_optional_dependency("matplotlib.pyplot")
gridspec = import_optional_dependency("matplotlib.gridspec")
cm = import_optional_dependency("matplotlib", "cm")
colors = import_optional_dependency("matplotlib", "colors")

t_in_seconds = t * self.time_scaling_factor
self.fig = plt.figure(figsize=self.figsize)
Expand Down Expand Up @@ -674,8 +674,8 @@ def dynamic_plot(self, show_plot=True, step=None):
continuous_update=False,
)
else:
plt = have_optional_dependency("matplotlib.pyplot")
Slider = have_optional_dependency("matplotlib.widgets", "Slider")
plt = import_optional_dependency("matplotlib.pyplot")
Slider = import_optional_dependency("matplotlib.widgets", "Slider")

# create an initial plot at time self.min_t
self.plot(self.min_t, dynamic=True)
Expand Down Expand Up @@ -779,8 +779,8 @@ def create_gif(self, number_of_images=80, duration=0.1, output_filename="plot.gi
Name of the generated GIF file.
"""
imageio = have_optional_dependency("imageio.v2")
plt = have_optional_dependency("matplotlib.pyplot")
imageio = import_optional_dependency("imageio.v2")
plt = import_optional_dependency("matplotlib.pyplot")

# time stamps at which the images/plots will be created
time_array = np.linspace(self.min_t, self.max_t, num=number_of_images)
Expand Down
4 changes: 2 additions & 2 deletions pybamm/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import sys
from functools import lru_cache
from datetime import timedelta
from pybamm.util import have_optional_dependency
from pybamm.util import import_optional_dependency

from pybamm.expression_tree.operations.serialise import Serialise

Expand Down Expand Up @@ -701,7 +701,7 @@ def solve(

# check if a user has tqdm installed
if showprogress:
tqdm = have_optional_dependency("tqdm")
tqdm = import_optional_dependency("tqdm")
cycle_lengths = tqdm.tqdm(
self.experiment.cycle_lengths,
desc="Cycling",
Expand Down
Loading

0 comments on commit 885fe90

Please sign in to comment.