Skip to content

Commit

Permalink
Fast Hermite interpolation and observables (#4464)
Browse files Browse the repository at this point in the history
* fast observables

* fix release and types

* good

* faster interp

* `double` -> `realtype`

* clean separation

* good again

* private members

* cleanup

* fix codecov, tests

* naming

* codecov

* codacy, cse/expand

* fix `try/except`

* Update CHANGELOG.md

* address comments

* initialize `save_hermite`

* fix codecov

---------

Co-authored-by: Eric G. Kratz <[email protected]>
  • Loading branch information
MarcBerliner and kratman authored Oct 2, 2024
1 parent 08e5cf2 commit 444ecc1
Show file tree
Hide file tree
Showing 26 changed files with 2,021 additions and 610 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Features

- Added Hermite interpolation to the (`IDAKLUSolver`) that improves the accuracy and performance of post-processing variables. ([#4464](https://github.com/pybamm-team/PyBaMM/pull/4464))
- Added sensitivity calculation support for `pybamm.Simulation` and `pybamm.Experiment` ([#4415](https://github.com/pybamm-team/PyBaMM/pull/4415))
- Added OpenMP parallelization to IDAKLU solver for lists of input parameters ([#4449](https://github.com/pybamm-team/PyBaMM/pull/4449))
- Added phase-dependent particle options to LAM
Expand Down
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ pybind11_add_module(idaklu
src/pybamm/solvers/c_solvers/idaklu/Expressions/Base/Expression.hpp
src/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionSet.hpp
src/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionTypes.hpp
src/pybamm/solvers/c_solvers/idaklu/observe.hpp
src/pybamm/solvers/c_solvers/idaklu/observe.cpp
# IDAKLU expressions - concrete implementations
${IDAKLU_EXPR_CASADI_SOURCE_FILES}
${IDAKLU_EXPR_IREE_SOURCE_FILES}
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ def compile_KLU():
"src/pybamm/solvers/c_solvers/idaklu/Solution.hpp",
"src/pybamm/solvers/c_solvers/idaklu/Options.hpp",
"src/pybamm/solvers/c_solvers/idaklu/Options.cpp",
"src/pybamm/solvers/c_solvers/idaklu/observe.hpp",
"src/pybamm/solvers/c_solvers/idaklu/observe.cpp",
"src/pybamm/solvers/c_solvers/idaklu.cpp",
],
)
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@

# Solver classes
from .solvers.solution import Solution, EmptySolution, make_cycle_solution
from .solvers.processed_variable import ProcessedVariable
from .solvers.processed_variable import ProcessedVariable, process_variable
from .solvers.processed_variable_computed import ProcessedVariableComputed
from .solvers.base_solver import BaseSolver
from .solvers.dummy_solver import DummySolver
Expand Down
17 changes: 8 additions & 9 deletions src/pybamm/plotting/quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,14 +419,14 @@ def reset_axis(self):
spatial_vars = self.spatial_variable_dict[key]
var_min = np.min(
[
ax_min(var(self.ts_seconds[i], **spatial_vars, warn=False))
ax_min(var(self.ts_seconds[i], **spatial_vars))
for i, variable_list in enumerate(variable_lists)
for var in variable_list
]
)
var_max = np.max(
[
ax_max(var(self.ts_seconds[i], **spatial_vars, warn=False))
ax_max(var(self.ts_seconds[i], **spatial_vars))
for i, variable_list in enumerate(variable_lists)
for var in variable_list
]
Expand Down Expand Up @@ -512,7 +512,7 @@ def plot(self, t, dynamic=False):
full_t = self.ts_seconds[i]
(self.plots[key][i][j],) = ax.plot(
full_t / self.time_scaling_factor,
variable(full_t, warn=False),
variable(full_t),
color=self.colors[i],
linestyle=linestyle,
)
Expand Down Expand Up @@ -548,7 +548,7 @@ def plot(self, t, dynamic=False):
linestyle = self.linestyles[j]
(self.plots[key][i][j],) = ax.plot(
self.first_spatial_variable[key],
variable(t_in_seconds, **spatial_vars, warn=False),
variable(t_in_seconds, **spatial_vars),
color=self.colors[i],
linestyle=linestyle,
zorder=10,
Expand All @@ -570,13 +570,13 @@ def plot(self, t, dynamic=False):
y_name = next(iter(spatial_vars.keys()))[0]
x = self.second_spatial_variable[key]
y = self.first_spatial_variable[key]
var = variable(t_in_seconds, **spatial_vars, warn=False)
var = variable(t_in_seconds, **spatial_vars)
else:
x_name = next(iter(spatial_vars.keys()))[0]
y_name = list(spatial_vars.keys())[1][0]
x = self.first_spatial_variable[key]
y = self.second_spatial_variable[key]
var = variable(t_in_seconds, **spatial_vars, warn=False).T
var = variable(t_in_seconds, **spatial_vars).T
ax.set_xlabel(f"{x_name} [{self.spatial_unit}]")
ax.set_ylabel(f"{y_name} [{self.spatial_unit}]")
vmin, vmax = self.variable_limits[key]
Expand Down Expand Up @@ -710,7 +710,6 @@ def slider_update(self, t):
var = variable(
time_in_seconds,
**self.spatial_variable_dict[key],
warn=False,
)
plot[i][j].set_ydata(var)
var_min = min(var_min, ax_min(var))
Expand All @@ -729,11 +728,11 @@ def slider_update(self, t):
if self.x_first_and_y_second[key] is False:
x = self.second_spatial_variable[key]
y = self.first_spatial_variable[key]
var = variable(time_in_seconds, **spatial_vars, warn=False)
var = variable(time_in_seconds, **spatial_vars)
else:
x = self.first_spatial_variable[key]
y = self.second_spatial_variable[key]
var = variable(time_in_seconds, **spatial_vars, warn=False).T
var = variable(time_in_seconds, **spatial_vars).T
# store the plot and the var data (for testing) as cant access
# z data from QuadMesh or QuadContourSet object
if self.is_y_z[key] is True:
Expand Down
26 changes: 26 additions & 0 deletions src/pybamm/solvers/c_solvers/idaklu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <pybind11/stl_bind.h>

#include "idaklu/idaklu_solver.hpp"
#include "idaklu/observe.hpp"
#include "idaklu/IDAKLUSolverGroup.hpp"
#include "idaklu/IdakluJax.hpp"
#include "idaklu/common.hpp"
Expand All @@ -27,13 +28,15 @@ casadi::Function generate_casadi_function(const std::string &data)
namespace py = pybind11;

PYBIND11_MAKE_OPAQUE(std::vector<np_array>);
PYBIND11_MAKE_OPAQUE(std::vector<np_array_realtype>);
PYBIND11_MAKE_OPAQUE(std::vector<Solution>);

PYBIND11_MODULE(idaklu, m)
{
m.doc() = "sundials solvers"; // optional module docstring

py::bind_vector<std::vector<np_array>>(m, "VectorNdArray");
py::bind_vector<std::vector<np_array_realtype>>(m, "VectorRealtypeNdArray");
py::bind_vector<std::vector<Solution>>(m, "VectorSolution");

py::class_<IDAKLUSolverGroup>(m, "IDAKLUSolverGroup")
Expand Down Expand Up @@ -72,6 +75,27 @@ PYBIND11_MODULE(idaklu, m)
py::arg("options"),
py::return_value_policy::take_ownership);

m.def("observe", &observe,
"Observe variables",
py::arg("ts"),
py::arg("ys"),
py::arg("inputs"),
py::arg("funcs"),
py::arg("is_f_contiguous"),
py::arg("shape"),
py::return_value_policy::take_ownership);

m.def("observe_hermite_interp", &observe_hermite_interp,
"Observe and Hermite interpolate variables",
py::arg("t_interp"),
py::arg("ts"),
py::arg("ys"),
py::arg("yps"),
py::arg("inputs"),
py::arg("funcs"),
py::arg("shape"),
py::return_value_policy::take_ownership);

#ifdef IREE_ENABLE
m.def("create_iree_solver_group", &create_idaklu_solver_group<IREEFunctions>,
"Create a group of iree idaklu solver objects",
Expand Down Expand Up @@ -167,7 +191,9 @@ PYBIND11_MODULE(idaklu, m)
py::class_<Solution>(m, "solution")
.def_readwrite("t", &Solution::t)
.def_readwrite("y", &Solution::y)
.def_readwrite("yp", &Solution::yp)
.def_readwrite("yS", &Solution::yS)
.def_readwrite("ypS", &Solution::ypS)
.def_readwrite("y_term", &Solution::y_term)
.def_readwrite("flag", &Solution::flag);
}
43 changes: 40 additions & 3 deletions src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver
int const number_of_events; // cppcheck-suppress unusedStructMember
int number_of_timesteps;
int precon_type; // cppcheck-suppress unusedStructMember
N_Vector yy, yp, y_cache, avtol; // y, y', y cache vector, and absolute tolerance
N_Vector yy, yyp, y_cache, avtol; // y, y', y cache vector, and absolute tolerance
N_Vector *yyS; // cppcheck-suppress unusedStructMember
N_Vector *ypS; // cppcheck-suppress unusedStructMember
N_Vector *yypS; // cppcheck-suppress unusedStructMember
N_Vector id; // rhs_alg_id
realtype rtol;
int const jac_times_cjmass_nnz; // cppcheck-suppress unusedStructMember
Expand All @@ -70,11 +70,14 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver
vector<realtype> res_dvar_dp;
bool const sensitivity; // cppcheck-suppress unusedStructMember
bool const save_outputs_only; // cppcheck-suppress unusedStructMember
bool save_hermite; // cppcheck-suppress unusedStructMember
bool is_ODE; // cppcheck-suppress unusedStructMember
int length_of_return_vector; // cppcheck-suppress unusedStructMember
vector<realtype> t; // cppcheck-suppress unusedStructMember
vector<vector<realtype>> y; // cppcheck-suppress unusedStructMember
vector<vector<realtype>> yp; // cppcheck-suppress unusedStructMember
vector<vector<vector<realtype>>> yS; // cppcheck-suppress unusedStructMember
vector<vector<vector<realtype>>> ypS; // cppcheck-suppress unusedStructMember
SetupOptions const setup_opts;
SolverOptions const solver_opts;

Expand Down Expand Up @@ -144,6 +147,11 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver
*/
void InitializeStorage(int const N);

/**
* @brief Initialize the storage for Hermite interpolation
*/
void InitializeHermiteStorage(int const N);

/**
* @brief Apply user-configurable IDA options
*/
Expand Down Expand Up @@ -190,13 +198,20 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver
*/
void ExtendAdaptiveArrays();

/**
* @brief Extend the Hermite interpolation info by 1
*/
void ExtendHermiteArrays();

/**
* @brief Set the step values
*/
void SetStep(
realtype &t_val,
realtype &tval,
realtype *y_val,
realtype *yp_val,
vector<realtype *> const &yS_val,
vector<realtype *> const &ypS_val,
int &i_save
);

Expand All @@ -211,7 +226,9 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver
realtype &t_prev,
realtype const &t_next,
realtype *y_val,
realtype *yp_val,
vector<realtype *> const &yS_val,
vector<realtype *> const &ypS_val,
int &i_save
);

Expand Down Expand Up @@ -255,6 +272,26 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver
int &i_save
);

/**
* @brief Save the output function results at the requested time
*/
void SetStepHermite(
realtype &t_val,
realtype *yp_val,
const vector<realtype*> &ypS_val,
int &i_save
);

/**
* @brief Save the output function sensitivities at the requested time
*/
void SetStepHermiteSensitivities(
realtype &t_val,
realtype *yp_val,
const vector<realtype*> &ypS_val,
int &i_save
);

};

#include "IDAKLUSolverOpenMP.inl"
Expand Down
Loading

0 comments on commit 444ecc1

Please sign in to comment.