Skip to content

Commit

Permalink
Merge pull request #1552 from pybamm-team/issue-1477-idaklu-send
Browse files Browse the repository at this point in the history
Issue 1477 sensitivities for solvers
  • Loading branch information
martinjrobins authored Aug 19, 2021
2 parents 6f4a152 + e898c65 commit 404fa2e
Show file tree
Hide file tree
Showing 30 changed files with 1,990 additions and 266 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Features

- `pybamm.base_solver.solve` function can take a list of input parameters to calculate the sensitivities of the solution with respect to. Alternatively, it can be set to `True` to calculate the sensitivities for all input parameters ([#1552](https://github.com/pybamm-team/PyBaMM/pull/1552))
- Addes UDDS and WLTC drive cycles ([#1601](https://github.com/pybamm-team/PyBaMM/pull/1601))
- Added capability for `quaternary` domains (in addition to `primary`, `secondary` and `tertiary`), increasing the maximum number of domains that a `Symbol` can have to 4. ([#1580](https://github.com/pybamm-team/PyBaMM/pull/1580))
- Tabs can now be placed at the bottom of the cell in 1+1D thermal models ([#1581](https://github.com/pybamm-team/PyBaMM/pull/1581))
Expand Down Expand Up @@ -42,6 +43,7 @@

## Breaking changes

- Changed sensitivity API. Removed `ProcessedSymbolicVariable`, all sensitivity now handled within the solvers and `ProcessedVariable` () ([#1552](https://github.com/pybamm-team/PyBaMM/pull/1552))
- The `Yang2017` parameter set has been removed as the complete parameter set is not publicly available in the literature ([#1577](https://github.com/pybamm-team/PyBaMM/pull/1577))
- Changed how options are specified for the "loss of active material" and "particle cracking" submodels. "loss of active material" can now be one of "none", "stress-driven", or "reaction-driven", or a 2-tuple for different options in negative and positive electrode. Similarly "particle cracking" (now called "particle mechanics") can now be "none", "swelling only", "swelling and cracking", or a 2-tuple ([#1490](https://github.com/pybamm-team/PyBaMM/pull/1490))
- Changed the variable in the full diffusion model from "Electrolyte concentration" to "Porosity times concentration" ([#1476](https://github.com/pybamm-team/PyBaMM/pull/1476))
Expand Down Expand Up @@ -189,7 +191,6 @@ This release adds new operators for more complex models, some basic sensitivity
(e.g. `standard_parameters_lithium_ion` is now `LithiumIonParameters`) ([#1120](https://github.com/pybamm-team/PyBaMM/pull/1120))
- Renamed `quick_plot_vars` to `output_variables` in `Simulation` to be consistent with `QuickPlot`. Passing `quick_plot_vars` to `Simulation.plot()` has been deprecated and `output_variables` should be passed instead ([#1099](https://github.com/pybamm-team/PyBaMM/pull/1099))


# [v0.2.3](https://github.com/pybamm-team/PyBaMM/tree/v0.2.3) - 2020-07-01

This release enables the use of [Google Colab](https://colab.research.google.com/github/pybamm-team/PyBaMM/blob/main/) for running example notebooks, and adds some small new features and bug fixes.
Expand Down
4 changes: 2 additions & 2 deletions FindSUNDIALS.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
# find the SUNDIALS include directories
find_path(SUNDIALS_INCLUDE_DIR
NAMES
ida/ida.h
idas/idas.h
sundials/sundials_math.h
sundials/sundials_types.h
sunlinsol/sunlinsol_klu.h
Expand All @@ -39,7 +39,7 @@ find_path(SUNDIALS_INCLUDE_DIR
)

set(SUNDIALS_WANT_COMPONENTS
sundials_ida
sundials_idas
sundials_sunlinsolklu
sundials_sunmatrixsparse
sundials_nvecserial
Expand Down
3 changes: 0 additions & 3 deletions docs/source/solvers/processed_variable.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,3 @@ Post-Process Variables

.. autoclass:: pybamm.ProcessedVariable
:members:

.. autoclass:: pybamm.ProcessedSymbolicVariable
:members:
5 changes: 5 additions & 0 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,11 @@ def process_model(self, model, inplace=True, check_model=True):
model_disc.rhs, model_disc.concatenated_rhs = rhs, concat_rhs
model_disc.algebraic, model_disc.concatenated_algebraic = alg, concat_alg

# Save length of rhs and algebraic
model_disc.len_rhs = model_disc.concatenated_rhs.size
model_disc.len_alg = model_disc.concatenated_algebraic.size
model_disc.len_rhs_and_alg = model_disc.len_rhs + model_disc.len_alg

# Process events
processed_events = []
pybamm.logger.verbose("Discretise events for {}".format(model.name))
Expand Down
12 changes: 12 additions & 0 deletions pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ def __str__(self):
out = out[:-2] + ")"
return out

def _diff(self, variable):
""" See :meth:`pybamm.Symbol._diff()`. """
children_diffs = [
child.diff(variable) for child in self.cached_children
]
if len(children_diffs) == 1:
diff = children_diffs[0]
else:
diff = self.__class__(*children_diffs)

return diff

def get_children_domains(self, children):
# combine domains from children
domain = []
Expand Down
60 changes: 49 additions & 11 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def __init__(self, symbol):
constants[symbol_id] = jax.device_put(constants[symbol_id])

# get a list of constant arguments to input to the function
arg_list = [
self._arg_list = [
id_to_python_variable(symbol_id, True) for symbol_id in constants.keys()
]

Expand All @@ -580,9 +580,11 @@ def __init__(self, symbol):

# add function def to first line
args = "t=None, y=None, y_dot=None, inputs=None, known_evals=None"
if arg_list:
args = ",".join(arg_list) + ", " + args
python_str = "def evaluate_jax({}):\n".format(args) + python_str
if self._arg_list:
args = ",".join(self._arg_list) + ", " + args
python_str = (
"def evaluate_jax({}):\n".format(args) + python_str
)

# calculate the final variable that will output the result of calling `evaluate`
# on `symbol`
Expand All @@ -606,17 +608,32 @@ def __init__(self, symbol):
compiled_function = compile(python_str, result_var, "exec")
exec(compiled_function)

n = len(arg_list)
static_argnums = tuple(static_argnums)
self._jit_evaluate = jax.jit(self._evaluate_jax, static_argnums=static_argnums)
self._static_argnums = tuple(static_argnums)
self._jit_evaluate = jax.jit(self._evaluate_jax,
static_argnums=self._static_argnums)

def get_jacobian(self):
n = len(self._arg_list)

# store a jit version of evaluate_jax's jacobian
# forward mode autodiff wrt y, which is argument 1 after arg_list
jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=1 + n)
self._jac_evaluate = jax.jit(jacobian_evaluate, static_argnums=static_argnums)

def get_jacobian(self):
self._jac_evaluate = jax.jit(jacobian_evaluate,
static_argnums=self._static_argnums)

return EvaluatorJaxJacobian(self._jac_evaluate, self._constants)

def get_sensitivities(self):
n = len(self._arg_list)

# forward mode autodiff wrt inputs, which is argument 3 after arg_list
jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=3 + n)

self._sens_evaluate = jax.jit(jacobian_evaluate,
static_argnums=self._static_argnums)

return EvaluatorJaxSensitivities(self._sens_evaluate, self._constants)

def debug(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
# generated code assumes y is a column vector
if y is not None and y.ndim == 1:
Expand Down Expand Up @@ -668,7 +685,28 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
result = self._jac_evaluate(*self._constants, t, y, y_dot, inputs, known_evals)
result = result.reshape(result.shape[0], -1)

# don't need known_evals, but need to reproduce Symbol.evaluate signature
if known_evals is not None:
return result, known_evals
else:
return result


class EvaluatorJaxSensitivities:
def __init__(self, jac_evaluate, constants):
self._jac_evaluate = jac_evaluate
self._constants = constants

def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
"""
Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate`
"""
# generated code assumes y is a column vector
if y is not None and y.ndim == 1:
y = y.reshape(-1, 1)

# execute code
result = self._jac_evaluate(*self._constants, t, y, y_dot, inputs, known_evals)

if known_evals is not None:
return result, known_evals
else:
Expand Down
13 changes: 13 additions & 0 deletions pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,19 @@ def timescale(self, value):
"""Set the timescale"""
self._timescale = value

@property
def length_scales(self):
"Length scales of model"
return self._length_scale

@length_scales.setter
def length_scales(self, values):
"Set the length scale, converting any numbers to pybamm.Scalar"
for domain, scale in values.items():
if isinstance(scale, numbers.Number):
values[domain] = pybamm.Scalar(scale)
self._length_scale = values

@property
def parameters(self):
"""Returns all the parameters in the model"""
Expand Down
Loading

0 comments on commit 404fa2e

Please sign in to comment.