Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve default imports for optional dependencies #3475

Merged
merged 54 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from 53 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
4c237c1
prevent `pybtex` default installation
arjxn-py Oct 26, 2023
6d30b3a
resolve `anytree` default installation
arjxn-py Oct 26, 2023
e3b3b35
resolve `autograd` default imports
arjxn-py Oct 26, 2023
4dd2317
resolve `skfem` default imports
arjxn-py Oct 26, 2023
9e24562
resolve `tqdm` default imports
arjxn-py Oct 26, 2023
50315e7
Raise import error for `anytree` requiring functions
arjxn-py Oct 28, 2023
e09fcea
Make simple function to check optional dependency
arjxn-py Oct 29, 2023
a07b342
Make decorater function
arjxn-py Oct 30, 2023
9d9db2b
Make normal reusable function for optional deps
arjxn-py Oct 30, 2023
34311ee
Update `citations.py` for `pybtex` as optional dependency
arjxn-py Oct 30, 2023
1218065
Execute silently, raise ImportError & import function correctly
arjxn-py Oct 30, 2023
90ac2ee
Update `Symbol` for `anytree` as optional dependency
arjxn-py Oct 30, 2023
3e68617
Update `simulation` for `tqdm` as optional dependency
arjxn-py Nov 1, 2023
5551dac
Update `Function` class for `autograd` as optional dependency
arjxn-py Nov 1, 2023
64d9037
Resolve `scikit-fem` based methods
arjxn-py Nov 1, 2023
9ee911b
Resolve `sympy` based methods
arjxn-py Nov 1, 2023
efe8877
Fix Typo
arjxn-py Nov 1, 2023
9111055
Return more helpful message
arjxn-py Nov 1, 2023
c65a2a2
Abstraction to only show module name if not available
arjxn-py Nov 3, 2023
4d32e32
Update docs for have_optional_deps
arjxn-py Nov 3, 2023
e38205e
Merge branch 'develop' into fix-default-imports
arjxn-py Nov 3, 2023
fd09163
Update for `have_optional_dependency`
arjxn-py Nov 6, 2023
926f8d7
Add comments to `have_optional_dependency`
arjxn-py Nov 6, 2023
ec963d1
Add `test_have_optional_dependency`
arjxn-py Nov 8, 2023
dd8a6f2
Apply suggestions from code review
arjxn-py Nov 9, 2023
aa2327e
style: pre-commit fixes
pre-commit-ci[bot] Nov 9, 2023
c28c7fb
Raise simple ModuleNotFoundError even if attribute not found
arjxn-py Nov 9, 2023
fd9ae61
Set pybtex to None to avoid import
arjxn-py Nov 9, 2023
7cb2ef6
Add more testcases for optional dependencies
arjxn-py Nov 9, 2023
b681bbc
Add test for case if dependency is available
arjxn-py Nov 9, 2023
f2e37cf
Reset pybtex to run dependent function
arjxn-py Nov 9, 2023
8d6db99
Add test for full coverage
arjxn-py Nov 9, 2023
700ab5a
Declare `anytree` onn top to pass `test_is_constant_and_can_evaluate`
arjxn-py Nov 9, 2023
bfbe41e
Apply suggestions from code review
arjxn-py Nov 10, 2023
2f1d3ce
Shorten assert string
arjxn-py Nov 10, 2023
fe6b910
Improve readibility & add case to fix coverage
arjxn-py Nov 10, 2023
6239653
Modify CONTRIBUTING.md for optional dependency tests
arjxn-py Nov 10, 2023
6f5823f
Prevent inheriting LatexPrinter instead use a function
arjxn-py Nov 11, 2023
c093d44
Remove redundant testcase
arjxn-py Nov 11, 2023
a5d2573
Add `anytree` to required & install `[plot,cite]` in `examples` session
arjxn-py Nov 13, 2023
a3952dd
Set iterator based upon `tqdm`
arjxn-py Nov 13, 2023
ae22805
Clean up tqdm mess
Saransh-cpp Nov 13, 2023
b5f74ad
Fix matplotlib errors
Saransh-cpp Nov 13, 2023
78792bc
Apply suggestions from code review
Saransh-cpp Nov 13, 2023
a8ac4c7
Remove test for tqdm as ModuleNotFoundError no longer being raised fo…
arjxn-py Nov 13, 2023
bd2d009
Fix sympy overrides
Saransh-cpp Nov 13, 2023
0b32653
Resolve conflicts
Saransh-cpp Nov 13, 2023
9d342c0
fix tabs
Saransh-cpp Nov 13, 2023
2e30131
Fix CustomPrinter
Saransh-cpp Nov 13, 2023
4711362
Fix test
Saransh-cpp Nov 13, 2023
05c74e4
Add `anytree` to required deps in docs
arjxn-py Nov 13, 2023
4d118ab
Fix CHANGELOG
Saransh-cpp Nov 14, 2023
2213407
Merge branch 'fix-default-imports' of https://github.com/arjxn-py/PyB…
Saransh-cpp Nov 14, 2023
d685c38
Merge branch 'develop' into fix-default-imports
agriyakhetarpal Nov 14, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## Bug fixes

- Fixed a bug where the JaxSolver would fails when using GPU support with no input parameters ([#3423](https://github.com/pybamm-team/PyBaMM/pull/3423))
- Fixed bug in calculation of theoretical energy that made it very slow ([#3506](https://github.com/pybamm-team/PyBaMM/pull/3506))

# [v23.9rc0](https://github.com/pybamm-team/PyBaMM/tree/v23.9rc0) - 2023-10-31

Expand All @@ -23,6 +23,7 @@

## Bug fixes

- Fixed a bug where the JaxSolver would fail when using GPU support with no input parameters ([#3423](https://github.com/pybamm-team/PyBaMM/pull/3423))
- Fixed a bug where empty lists passed to QuickPlot resulted in an IndexError and did not return a meaningful error message ([#3359](https://github.com/pybamm-team/PyBaMM/pull/3359))
- Fixed a bug where there was a missing thermal conductivity in the thermal pouch cell models ([#3330](https://github.com/pybamm-team/PyBaMM/pull/3330))
- Fixed a bug that caused incorrect results of “{Domain} electrode thickness change [m]” due to the absence of dimension for the variable `electrode_thickness_change`([#3329](https://github.com/pybamm-team/PyBaMM/pull/3329)).
Expand Down Expand Up @@ -61,7 +62,7 @@
- Added option to use an empirical hysteresis model for the diffusivity and exchange-current density ([#3194](https://github.com/pybamm-team/PyBaMM/pull/3194))
- Double-layer capacity can now be provided as a function of temperature ([#3174](https://github.com/pybamm-team/PyBaMM/pull/3174))
- `pybamm_install_jax` is deprecated. It is now replaced with `pip install pybamm[jax]` ([#3163](https://github.com/pybamm-team/PyBaMM/pull/3163))
- PyBaMM now has optional dependencies that can be installed with the pattern `pip install pybamm[option]` e.g. `pybamm[plot]` ([#3044](https://github.com/pybamm-team/PyBaMM/pull/3044))
- PyBaMM now has optional dependencies that can be installed with the pattern `pip install pybamm[option]` e.g. `pybamm[plot]` ([#3044](https://github.com/pybamm-team/PyBaMM/pull/3044), [#3475](https://github.com/pybamm-team/PyBaMM/pull/3475))

# [v23.5](https://github.com/pybamm-team/PyBaMM/tree/v23.5) - 2023-06-18

Expand Down
50 changes: 38 additions & 12 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,21 +100,52 @@ On the other hand... We _do_ want to compare several tools, to generate document

Only 'core pybamm' is installed by default. The others have to be specified explicitly when running the installation command.

### Matplotlib
### Managing Optional Dependencies and Their Imports

We use Matplotlib in PyBaMM, but with two caveats:
PyBaMM utilizes optional dependencies to allow users to choose which additional libraries they want to use. Managing these optional dependencies and their imports is essential to provide flexibility to PyBaMM users.

First, Matplotlib should only be used in plotting methods, and these should _never_ be called by other PyBaMM methods. So users who don't like Matplotlib will not be forced to use it in any way. Use in notebooks is OK and encouraged.
PyBaMM provides a utility function `have_optional_dependency`, to check for the availability of optional dependencies within methods. This function can be used to conditionally import optional dependencies only if they are available. Here's how to use it:

Second, Matplotlib should never be imported at the module level, but always inside methods. For example:
Optional dependencies should never be imported at the module level, but always inside methods. For example:

```
def plot_great_things(self, x, y, z):
import matplotlib.pyplot as pl
def use_pybtex(x,y,z):
pybtex = have_optional_dependency("pybtex")
...
```

This allows people to (1) use PyBaMM without ever importing Matplotlib and (2) configure Matplotlib's back-end in their scripts, which _must_ be done before e.g. `pyplot` is first imported.
While importing a specific module instead of an entire package/library:

```python
def use_parse_file(x, y, z):
parse_file = have_optional_dependency("pybtex.database", "parse_file")
...
```

This allows people to (1) use PyBaMM without importing optional dependencies by default and (2) configure module-dependent functionalities in their scripts, which _must_ be done before e.g. `print_citations` method is first imported.

**Writing Tests for Optional Dependencies**

Whenever a new optional dependency is added for optional functionality, it is recommended to write a corresponding unit test in `test_util.py`. This ensures that an error is raised upon the absence of said dependency. Here's an example:

```python
from tests import TestCase
import pybamm


class TestUtil(TestCase):
def test_optional_dependency(self):
# Test that an error is raised when pybtex is not available
with self.assertRaisesRegex(
ModuleNotFoundError, "Optional dependency pybtex is not available"
):
sys.modules["pybtex"] = None
pybamm.function_using_pybtex(x, y, z)

# Test that the function works when pybtex is available
sys.modules["pybtex"] = pybamm.util.have_optional_dependency("pybtex")
pybamm.function_using_pybtex(x, y, z)
```

## Testing

Expand Down Expand Up @@ -266,7 +297,6 @@ This also means that, if you can't fix the bug yourself, it will be much easier
```

This will start the debugger at the point where the `ValueError` was raised, and allow you to investigate further. Sometimes, it is more informative to put the try-except block further up the call stack than exactly where the error is raised.

2. Warnings. If functions are raising warnings instead of errors, it can be hard to pinpoint where this is coming from. Here, you can use the `warnings` module to convert warnings to errors:

```python
Expand All @@ -276,19 +306,15 @@ This also means that, if you can't fix the bug yourself, it will be much easier
```

Then you can use a try-except block, as in a., but with, for example, `RuntimeWarning` instead of `ValueError`.

3. Stepping through the expression tree. Most calls in PyBaMM are operations on [expression trees](https://github.com/pybamm-team/PyBaMM/blob/develop/docs/source/examples/notebooks/expression_tree/expression-tree.ipynb). To view an expression tree in ipython, you can use the `render` command:

```python
expression_tree.render()
```

You can then step through the expression tree, using the `children` attribute, to pinpoint exactly where a bug is coming from. For example, if `expression_tree.jac(y)` is failing, you can check `expression_tree.children[0].jac(y)`, then `expression_tree.children[0].children[0].jac(y)`, etc.

3. To isolate whether a bug is in a model, its Jacobian or its simplified version, you can set the `use_jacobian` and/or `use_simplify` attributes of the model to `False` (they are both `True` by default for most models).

4. If a model isn't giving the answer you expect, you can try comparing it to other models. For example, you can investigate parameter limits in which two models should give the same answer by setting some parameters to be small or zero. The `StandardOutputComparison` class can be used to compare some standard outputs from battery models.

5. To get more information about what is going on under the hood, and hence understand what is causing the bug, you can set the [logging](https://realpython.com/python-logging/) level to `DEBUG` by adding the following line to your test or script:

```python3
Expand Down
1 change: 1 addition & 0 deletions docs/source/user_guide/installation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Package Minimum support
`SciPy <https://docs.scipy.org/doc/scipy/>`__ 2.8.2
`CasADi <https://web.casadi.org/docs/>`__ 3.6.0
`Xarray <https://docs.xarray.dev/en/stable/>`__ 2023.04.0
`Anytree <https://anytree.readthedocs.io/en/stable/>`__ 2.4.3
================================================================ ==========================

.. _install.optional_dependencies:
Expand Down
2 changes: 1 addition & 1 deletion pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@
get_parameters_filepath,
have_jax,
install_jax,
have_optional_dependency,
is_jax_compatible,
get_git_commit_info,
)
from .logger import logger, set_logging_level, get_new_logger
from .settings import settings
from .citations import Citations, citations, print_citations

#
# Classes for the Expression Tree
#
Expand Down
9 changes: 6 additions & 3 deletions pybamm/citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
import pybamm
import os
import warnings
import pybtex
from sys import _getframe
from pybtex.database import parse_file, parse_string, Entry
from pybtex.scanner import PybtexError
from pybamm.util import have_optional_dependency


class Citations:
Expand Down Expand Up @@ -76,6 +74,7 @@ def read_citations(self):
"""Reads the citations in `pybamm.CITATIONS.bib`. Other works can be cited
by passing a BibTeX citation to :meth:`register`.
"""
parse_file = have_optional_dependency("pybtex.database", "parse_file")
citations_file = os.path.join(pybamm.root_dir(), "pybamm", "CITATIONS.bib")
bib_data = parse_file(citations_file, bib_format="bibtex")
for key, entry in bib_data.entries.items():
Expand All @@ -86,6 +85,7 @@ def _add_citation(self, key, entry):
previous entry is overwritten
"""

Entry = have_optional_dependency("pybtex.database", "Entry")
# Check input types are correct
if not isinstance(key, str) or not isinstance(entry, Entry):
raise TypeError()
Expand Down Expand Up @@ -151,6 +151,8 @@ def _parse_citation(self, key):
key: str
A BibTeX formatted citation
"""
PybtexError = have_optional_dependency("pybtex.scanner", "PybtexError")
parse_string = have_optional_dependency("pybtex.database", "parse_string")
try:
# Parse string as a bibtex citation, and check that a citation was found
bib_data = parse_string(key, bib_format="bibtex")
Expand Down Expand Up @@ -217,6 +219,7 @@ def print(self, filename=None, output_format="text", verbose=False):
"""
# Parse citations that were not known keys at registration, but do not
# fail if they cannot be parsed
pybtex = have_optional_dependency("pybtex")
try:
for key in self._unknown_citations:
self._parse_citation(key)
Expand Down
3 changes: 2 additions & 1 deletion pybamm/expression_tree/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
# NumpyArray class
#
import numpy as np
import sympy
from scipy.sparse import csr_matrix, issparse

import pybamm
from pybamm.util import have_optional_dependency


class Array(pybamm.Symbol):
Expand Down Expand Up @@ -125,6 +125,7 @@ def is_constant(self):

def to_equation(self):
"""Returns the value returned by the node when evaluated."""
sympy = have_optional_dependency("sympy")
entries_list = self.entries.tolist()
return sympy.Array(entries_list)

Expand Down
6 changes: 5 additions & 1 deletion pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import numbers

import numpy as np
import sympy
from scipy.sparse import csr_matrix, issparse
import functools

import pybamm
from pybamm.util import have_optional_dependency


def _preprocess_binary(left, right):
Expand Down Expand Up @@ -147,6 +147,7 @@ def _sympy_operator(self, left, right):

def to_equation(self):
"""Convert the node and its subtree into a SymPy equation."""
sympy = have_optional_dependency("sympy")
if self.print_name is not None:
return sympy.Symbol(self.print_name)
else:
Expand Down Expand Up @@ -323,6 +324,7 @@ def _binary_evaluate(self, left, right):

def _sympy_operator(self, left, right):
"""Override :meth:`pybamm.BinaryOperator._sympy_operator`"""
sympy = have_optional_dependency("sympy")
left = sympy.Matrix(left)
right = sympy.Matrix(right)
return left * right
Expand Down Expand Up @@ -626,6 +628,7 @@ def _binary_new_copy(self, left, right):

def _sympy_operator(self, left, right):
"""Override :meth:`pybamm.BinaryOperator._sympy_operator`"""
sympy = have_optional_dependency("sympy")
return sympy.Min(left, right)


Expand Down Expand Up @@ -662,6 +665,7 @@ def _binary_new_copy(self, left, right):

def _sympy_operator(self, left, right):
"""Override :meth:`pybamm.BinaryOperator._sympy_operator`"""
sympy = have_optional_dependency("sympy")
return sympy.Max(left, right)


Expand Down
3 changes: 2 additions & 1 deletion pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from collections import defaultdict

import numpy as np
import sympy
from scipy.sparse import issparse, vstack

import pybamm
from pybamm.util import have_optional_dependency


class Concatenation(pybamm.Symbol):
Expand Down Expand Up @@ -135,6 +135,7 @@ def is_constant(self):

def _sympy_operator(self, *children):
"""Apply appropriate SymPy operators."""
sympy = have_optional_dependency("sympy")
self.concat_latex = tuple(map(sympy.latex, children))

if self.print_name is not None:
Expand Down
9 changes: 6 additions & 3 deletions pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
#
import numbers

import autograd
import numpy as np
import sympy
from scipy import special

import pybamm

from pybamm.util import have_optional_dependency

class Function(pybamm.Symbol):
"""
Expand Down Expand Up @@ -96,6 +94,7 @@ def _function_diff(self, children, idx):
Derivative with respect to child number 'idx'.
See :meth:`pybamm.Symbol._diff()`.
"""
autograd = have_optional_dependency("autograd")
# Store differentiated function, needed in case we want to convert to CasADi
if self.derivative == "autograd":
return Function(
Expand Down Expand Up @@ -202,6 +201,7 @@ def _sympy_operator(self, child):

def to_equation(self):
"""Convert the node and its subtree into a SymPy equation."""
sympy = have_optional_dependency("sympy")
if self.print_name is not None:
return sympy.Symbol(self.print_name)
else:
Expand Down Expand Up @@ -250,6 +250,7 @@ def _function_new_copy(self, children):

def _sympy_operator(self, child):
"""Apply appropriate SymPy operators."""
sympy = have_optional_dependency("sympy")
class_name = self.__class__.__name__.lower()
sympy_function = getattr(sympy, class_name)
return sympy_function(child)
Expand All @@ -267,6 +268,7 @@ def _function_diff(self, children, idx):

def _sympy_operator(self, child):
"""Override :meth:`pybamm.Function._sympy_operator`"""
sympy = have_optional_dependency("sympy")
return sympy.asinh(child)


Expand All @@ -287,6 +289,7 @@ def _function_diff(self, children, idx):

def _sympy_operator(self, child):
"""Override :meth:`pybamm.Function._sympy_operator`"""
sympy = have_optional_dependency("sympy")
return sympy.atan(child)


Expand Down
5 changes: 3 additions & 2 deletions pybamm/expression_tree/independent_variable.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#
# IndependentVariable class
#
import sympy

import pybamm
from pybamm.util import have_optional_dependency

KNOWN_COORD_SYS = ["cartesian", "cylindrical polar", "spherical polar"]

Expand Down Expand Up @@ -44,6 +43,7 @@ def _jac(self, variable):

def to_equation(self):
"""Convert the node and its subtree into a SymPy equation."""
sympy = have_optional_dependency("sympy")
if self.print_name is not None:
return sympy.Symbol(self.print_name)
else:
Expand Down Expand Up @@ -77,6 +77,7 @@ def _evaluate_for_shape(self):

def to_equation(self):
"""Convert the node and its subtree into a SymPy equation."""
sympy = have_optional_dependency("sympy")
return sympy.Symbol("t")


Expand Down
6 changes: 4 additions & 2 deletions pybamm/expression_tree/operations/latexify.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
import re
import warnings

import sympy

import pybamm
from pybamm.expression_tree.printing.sympy_overrides import custom_print_func
from pybamm.util import have_optional_dependency


def get_rng_min_max_name(rng, min_or_max):
Expand Down Expand Up @@ -88,6 +87,7 @@ def _get_bcs_displays(self, var):
Returns a list of boundary condition equations with ranges in front of
the equations.
"""
sympy = have_optional_dependency("sympy")
bcs_eqn_list = []
bcs = self.model.boundary_conditions.get(var, None)

Expand Down Expand Up @@ -118,6 +118,7 @@ def _get_bcs_displays(self, var):

def _get_param_var(self, node):
"""Returns a list of parameters and a list of variables."""
sympy = have_optional_dependency("sympy")
param_list = []
var_list = []
dfs_nodes = [node]
Expand Down Expand Up @@ -160,6 +161,7 @@ def _get_param_var(self, node):
return param_list, var_list

def latexify(self, output_variables=None):
sympy = have_optional_dependency("sympy")
# Voltage is the default output variable if it exists
if output_variables is None:
if "Voltage [V]" in self.model.variables:
Expand Down
Loading