Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Idaklu solver can be given a list of variables to calculate during the solve #3217

Merged
merged 45 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
36a7275
Refactor of idaklu code
jsbrittain Jul 14, 2023
e693cef
Add convenience test code (quick compile & run)
jsbrittain Jul 14, 2023
97e7e0d
Add CSR support to idaklu solver
jsbrittain Jul 14, 2023
201ec7f
Add initial support for a list of variables' functions in idaklu
jsbrittain Jul 15, 2023
74def04
Collate vars or state vectors in idaklu; refactor Solution class stru…
jsbrittain Jul 16, 2023
df9568d
Fix incorrect input structure being passed when building variables' c…
jsbrittain Jul 18, 2023
9cb6cd3
Calculate sensitivities in Idaklu; uses temporary return structure
jsbrittain Jul 19, 2023
c8ad3c4
Provide new ProcessedVariable class when output_variables are specified
jsbrittain Jul 20, 2023
d7de190
Variables now correctly unroll in the ProcessedVariablesVar class
jsbrittain Jul 20, 2023
ef90a61
Reshape sensitivities using base variable parameters
jsbrittain Jul 20, 2023
c35fbae
Ensure further compatibility with existing tests
jsbrittain Jul 21, 2023
80d845e
Add unit tests for ProcessedVariableVar and modify Solution class tes…
jsbrittain Jul 21, 2023
7e3ea66
Merge branch 'pybamm-team:develop' into solver
jsbrittain Jul 21, 2023
3dfbc15
Account for ExplicitTimeIntegral variables
jsbrittain Aug 1, 2023
f09ca6d
Remove developer quicktest code
jsbrittain Aug 1, 2023
bc2e162
Code formatting
jsbrittain Aug 1, 2023
aadc287
Update changelog
jsbrittain Aug 1, 2023
1c83d00
Merge branch 'pybamm-team:develop' into solver
jsbrittain Aug 1, 2023
16c80ca
Resolve merge conflict
jsbrittain Aug 31, 2023
58cee45
Pre-commit tidy-up
jsbrittain Aug 31, 2023
2fb4e55
Improve testing / code-coverage
jsbrittain Aug 31, 2023
437a488
Codacy improvements
jsbrittain Aug 31, 2023
74f06f6
Codacy improvements
jsbrittain Aug 31, 2023
89222a5
Additional tests for idaklu solver with output_variables
jsbrittain Aug 31, 2023
d028668
Improve test coverage
jsbrittain Sep 1, 2023
86b9c01
Generalise idaklu matrix format conversions
jsbrittain Sep 1, 2023
b9b6d03
Merge branch 'pybamm-team:develop' into solver
jsbrittain Sep 7, 2023
653dc1d
Suggested changes
jsbrittain Sep 8, 2023
f442fb1
Improve documentation in idaklu solver
jsbrittain Sep 11, 2023
38e4a06
Refactor CasadiFunctionsOpenMP concrete implementations to macro gene…
jsbrittain Sep 11, 2023
bb56bea
Reinforce appropriate access rights (provide getters) for function ac…
jsbrittain Sep 11, 2023
21d68be
Rename ProcessedVariableVar class to ProcessedVariableComputed and up…
jsbrittain Sep 11, 2023
ab98190
Move casadi variable and sensitivity functions to the BaseSolver class
jsbrittain Sep 11, 2023
2e59e71
Merge branch 'pybamm-team:develop' into solver
jsbrittain Sep 11, 2023
f9b709a
Update test functions for ProcessedVariableComputed class
jsbrittain Sep 11, 2023
c3e4218
Fix IdakluSolver issue overwriting casadi functions and add additiona…
jsbrittain Sep 11, 2023
5c52220
style: pre-commit fixes
pre-commit-ci[bot] Sep 11, 2023
2187b89
Fix IDAKLUSolver crash on successive runs
jsbrittain Sep 12, 2023
7f9aa24
Refactor variable generation through the standard process() function
jsbrittain Sep 13, 2023
f381e31
Pre-commit fixes
jsbrittain Sep 13, 2023
174bec4
Streamline base_solver process()
jsbrittain Sep 18, 2023
67f825b
Merge branch 'pybamm-team:develop' into solver
jsbrittain Sep 19, 2023
e8ff644
Separate idaklu implementation classes
jsbrittain Sep 19, 2023
9ac8383
Codacy improvements
jsbrittain Sep 22, 2023
c7d109b
Merge branch 'pybamm-team:develop' into solver
jsbrittain Sep 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@

## Features

- Idaklu solver can be given a list of variables to calculate during the solve ([#3217](https://github.com/pybamm-team/PyBaMM/pull/3217))
- Enable multithreading in IDAKLU solver ([#2947](https://github.com/pybamm-team/PyBaMM/pull/2947))
- If a solution contains cycles and steps, the cycle number and step number are now saved when `solution.save_data()` is called ([#2931](https://github.com/pybamm-team/PyBaMM/pull/2931))
- Experiments can now be given a `start_time` to define when each step should be triggered ([#2616](https://github.com/pybamm-team/PyBaMM/pull/2616))
Expand Down
8 changes: 7 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,14 @@ pybind11_add_module(idaklu
pybamm/solvers/c_solvers/idaklu/casadi_functions.hpp
pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp
pybamm/solvers/c_solvers/idaklu/casadi_solver.hpp
pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.hpp
pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp
pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp
pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.cpp
pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP.hpp
pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.cpp
pybamm/solvers/c_solvers/idaklu/CasadiSolverOpenMP_solvers.hpp
pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.cpp
pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.hpp
pybamm/solvers/c_solvers/idaklu/common.hpp
pybamm/solvers/c_solvers/idaklu/python.hpp
pybamm/solvers/c_solvers/idaklu/python.cpp
Expand Down
1 change: 1 addition & 0 deletions pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@
#
from .solvers.solution import Solution, EmptySolution, make_cycle_solution
from .solvers.processed_variable import ProcessedVariable
from .solvers.processed_variable_computed import ProcessedVariableComputed
from .solvers.base_solver import BaseSolver
from .solvers.dummy_solver import DummySolver
from .solvers.algebraic_solver import AlgebraicSolver
Expand Down
94 changes: 82 additions & 12 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class BaseSolver(object):
The tolerance for the initial-condition solver (default is 1e-6).
extrap_tol : float, optional
The tolerance to assert whether extrapolation occurs or not. Default is 0.
output_variables : list[str], optional
List of variables to calculate and return. If none are specified then
the complete state vector is returned (can be very large) (default is [])
"""

def __init__(
Expand All @@ -48,20 +51,23 @@ def __init__(
root_method=None,
root_tol=1e-6,
extrap_tol=None,
output_variables=[],
):
self.method = method
self.rtol = rtol
self.atol = atol
self.root_tol = root_tol
self.root_method = root_method
self.extrap_tol = extrap_tol or -1e-10
self.output_variables = output_variables
self._model_set_up = {}

# Defaults, can be overwritten by specific solver
self.name = "Base solver"
self.ode_solver = False
self.algebraic_solver = False
self._on_extrapolation = "warn"
self.computed_var_fcns = {}

@property
def root_method(self):
Expand Down Expand Up @@ -250,8 +256,57 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
model.casadi_sensitivities_rhs = jacp_rhs
model.casadi_sensitivities_algebraic = jacp_algebraic

# if output_variables specified then convert functions to casadi
# expressions for evaluation within the respective solver
self.computed_var_fcns = {}
self.computed_dvar_dy_fcns = {}
self.computed_dvar_dp_fcns = {}
for key in self.output_variables:
# ExplicitTimeIntegral's are not computed as part of the solver and
# do not need to be converted
if isinstance(
model.variables_and_events[key], pybamm.ExplicitTimeIntegral
):
continue
# Generate Casadi function to calculate variable and derivates
# to enable sensitivites to be computed within the solver
(
self.computed_var_fcns[key],
self.computed_dvar_dy_fcns[key],
self.computed_dvar_dp_fcns[key],
_,
) = process(
model.variables_and_events[key],
BaseSolver._wrangle_name(key),
vars_for_processing,
use_jacobian=True,
return_jacp_stacked=True,
)

pybamm.logger.info("Finish solver set-up")

@classmethod
def _wrangle_name(cls, name: str) -> str:
"""
Wrangle a function name to replace special characters
"""
replacements = [
(" ", "_"),
("[", ""),
("]", ""),
(".", "_"),
("-", "_"),
("(", ""),
(")", ""),
("%", "prc"),
(",", ""),
(".", ""),
]
name = "v_" + name.casefold()
for string, replacement in replacements:
name = name.replace(string, replacement)
return name

def _check_and_prepare_model_inplace(self, model, inputs, ics_only):
"""
Performs checks on the model and prepares it for solving.
Expand Down Expand Up @@ -1366,7 +1421,9 @@ def _set_up_model_inputs(self, model, inputs):
return ordered_inputs


def process(symbol, name, vars_for_processing, use_jacobian=None):
def process(
symbol, name, vars_for_processing, use_jacobian=None, return_jacp_stacked=None
):
"""
Parameters
----------
Expand All @@ -1376,6 +1433,8 @@ def process(symbol, name, vars_for_processing, use_jacobian=None):
function evaluators created will have this base name
use_jacobian: bool, optional
whether to return Jacobian functions
return_jacp_stacked: bool, optional
returns Jacobian function wrt stacked parameters instead of jacp

Returns
-------
Expand Down Expand Up @@ -1553,17 +1612,28 @@ def jacp(*args, **kwargs):
"CasADi"
)
)
# WARNING, jacp for convert_to_format=casadi does not return a dict
# instead it returns multiple return values, one for each param
# TODO: would it be faster to do the jacobian wrt pS_casadi_stacked?
jacp = casadi.Function(
name + "_jacp",
[t_casadi, y_and_S, p_casadi_stacked],
[
casadi.densify(casadi.jacobian(casadi_expression, p_casadi[pname]))
for pname in model.calculate_sensitivities
],
)
# Compute derivate wrt p-stacked (can be passed to solver to
# compute sensitivities online)
if return_jacp_stacked:
jacp = casadi.Function(
f"d{name}_dp",
[t_casadi, y_casadi, p_casadi_stacked],
[casadi.jacobian(casadi_expression, p_casadi_stacked)],
)
else:
# WARNING, jacp for convert_to_format=casadi does not return a dict
# instead it returns multiple return values, one for each param
# TODO: would it be faster to do the jacobian wrt pS_casadi_stacked?
jacp = casadi.Function(
name + "_jacp",
[t_casadi, y_and_S, p_casadi_stacked],
[
casadi.densify(
casadi.jacobian(casadi_expression, p_casadi[pname])
)
for pname in model.calculate_sensitivities
],
)

if use_jacobian:
report(f"Calculating jacobian for {name} using CasADi")
Expand Down
91 changes: 64 additions & 27 deletions pybamm/solvers/c_solvers/idaklu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>

#include <vector>
Expand All @@ -25,39 +26,75 @@ PYBIND11_MODULE(idaklu, m)
py::bind_vector<std::vector<np_array>>(m, "VectorNdArray");

m.def("solve_python", &solve_python,
"The solve function for python evaluators", py::arg("t"), py::arg("y0"),
py::arg("yp0"), py::arg("res"), py::arg("jac"), py::arg("sens"),
py::arg("get_jac_data"), py::arg("get_jac_row_vals"),
py::arg("get_jac_col_ptr"), py::arg("nnz"), py::arg("events"),
py::arg("number_of_events"), py::arg("use_jacobian"),
py::arg("rhs_alg_id"), py::arg("atol"), py::arg("rtol"),
py::arg("inputs"), py::arg("number_of_sensitivity_parameters"),
py::return_value_policy::take_ownership);
"The solve function for python evaluators",
py::arg("t"),
py::arg("y0"),
py::arg("yp0"),
py::arg("res"),
py::arg("jac"),
py::arg("sens"),
py::arg("get_jac_data"),
py::arg("get_jac_row_vals"),
py::arg("get_jac_col_ptr"),
py::arg("nnz"),
py::arg("events"),
py::arg("number_of_events"),
py::arg("use_jacobian"),
py::arg("rhs_alg_id"),
py::arg("atol"),
py::arg("rtol"),
py::arg("inputs"),
py::arg("number_of_sensitivity_parameters"),
py::return_value_policy::take_ownership);

py::class_<CasadiSolver>(m, "CasadiSolver")
.def("solve", &CasadiSolver::solve, "perform a solve", py::arg("t"),
py::arg("y0"), py::arg("yp0"), py::arg("inputs"),
py::return_value_policy::take_ownership);
.def("solve", &CasadiSolver::solve,
"perform a solve",
py::arg("t"),
py::arg("y0"),
py::arg("yp0"),
py::arg("inputs"),
py::return_value_policy::take_ownership);

//py::bind_vector<std::vector<Function>>(m, "VectorFunction");
//py::implicitly_convertible<py::iterable, std::vector<Function>>();

m.def("create_casadi_solver", &create_casadi_solver,
"Create a casadi idaklu solver object", py::arg("number_of_states"),
py::arg("number_of_parameters"), py::arg("rhs_alg"),
py::arg("jac_times_cjmass"), py::arg("jac_times_cjmass_colptrs"),
py::arg("jac_times_cjmass_rowvals"), py::arg("jac_times_cjmass_nnz"),
py::arg("jac_bandwidth_lower"), py::arg("jac_bandwidth_upper"),
py::arg("jac_action"), py::arg("mass_action"), py::arg("sens"),
py::arg("events"), py::arg("number_of_events"), py::arg("rhs_alg_id"),
py::arg("atol"), py::arg("rtol"), py::arg("inputs"), py::arg("options"),
py::return_value_policy::take_ownership);

m.def("generate_function", &generate_function, "Generate a casadi function",
py::arg("string"), py::return_value_policy::take_ownership);
"Create a casadi idaklu solver object",
py::arg("number_of_states"),
py::arg("number_of_parameters"),
py::arg("rhs_alg"),
py::arg("jac_times_cjmass"),
py::arg("jac_times_cjmass_colptrs"),
py::arg("jac_times_cjmass_rowvals"),
py::arg("jac_times_cjmass_nnz"),
py::arg("jac_bandwidth_lower"),
py::arg("jac_bandwidth_upper"),
py::arg("jac_action"),
py::arg("mass_action"),
py::arg("sens"),
py::arg("events"),
py::arg("number_of_events"),
py::arg("rhs_alg_id"),
py::arg("atol"),
py::arg("rtol"),
py::arg("inputs"),
py::arg("var_casadi_fcns"),
py::arg("dvar_dy_fcns"),
py::arg("dvar_dp_fcns"),
py::arg("options"),
py::return_value_policy::take_ownership);

m.def("generate_function", &generate_function,
"Generate a casadi function",
py::arg("string"),
py::return_value_policy::take_ownership);

py::class_<Function>(m, "Function");

py::class_<Solution>(m, "solution")
.def_readwrite("t", &Solution::t)
.def_readwrite("y", &Solution::y)
.def_readwrite("yS", &Solution::yS)
.def_readwrite("flag", &Solution::flag);
.def_readwrite("t", &Solution::t)
.def_readwrite("y", &Solution::y)
.def_readwrite("yS", &Solution::yS)
.def_readwrite("flag", &Solution::flag);
}
1 change: 1 addition & 0 deletions pybamm/solvers/c_solvers/idaklu/CasadiSolver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#include "CasadiSolver.hpp"
49 changes: 49 additions & 0 deletions pybamm/solvers/c_solvers/idaklu/CasadiSolver.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#ifndef PYBAMM_IDAKLU_CASADI_SOLVER_HPP
#define PYBAMM_IDAKLU_CASADI_SOLVER_HPP

#include <casadi/casadi.hpp>
using Function = casadi::Function;

#include "casadi_functions.hpp"
#include "common.hpp"
#include "options.hpp"
#include "solution.hpp"
#include "sundials_legacy_wrapper.hpp"

/**
* Abstract base class for solutions that can use different solvers and vector
* implementations.
* @brief An abstract base class for the Idaklu solver
*/
class CasadiSolver
{
public:

/**
* @brief Default constructor
*/
CasadiSolver() = default;

/**
* @brief Default destructor
*/
~CasadiSolver() = default;

/**
* @brief Abstract solver method that returns a Solution class
*/
virtual Solution solve(
np_array t_np,
np_array y0_np,
np_array yp0_np,
np_array_dense inputs) = 0;

/**
* Abstract method to initialize the solver, once vectors and solver classes
* are set
* @brief Abstract initialization method
*/
virtual void Initialize() = 0;
};

#endif // PYBAMM_IDAKLU_CASADI_SOLVER_HPP
Loading