Skip to content

Commit

Permalink
Merge pull request #2947 from jsbrittain/openmp
Browse files Browse the repository at this point in the history
Enable multithreading in IDAKLU
  • Loading branch information
martinjrobins authored May 17, 2023
2 parents 6b0615b + 1b9a5cd commit 137ae23
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 29 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Features

- Enable multithreading in IDAKLU solver ([#2947](https://github.com/pybamm-team/PyBaMM/pull/2947))
- If a solution contains cycles and steps, the cycle number and step number are now saved when `solution.save_data()` is called ([#2931](https://github.com/pybamm-team/PyBaMM/pull/2931))

## Optimizations
Expand Down
1 change: 1 addition & 0 deletions FindSUNDIALS.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ set(SUNDIALS_WANT_COMPONENTS
sundials_sunlinsollapackdense
sundials_sunmatrixsparse
sundials_nvecserial
sundials_nvecopenmp
)

# find the SUNDIALS libraries
Expand Down
18 changes: 10 additions & 8 deletions pybamm/solvers/c_solvers/idaklu/casadi_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "common.hpp"
#include <memory>


CasadiSolver *
create_casadi_solver(int number_of_states, int number_of_parameters,
const Function &rhs_alg, const Function &jac_times_cjmass,
Expand Down Expand Up @@ -53,16 +54,17 @@ CasadiSolver::CasadiSolver(np_array atol_np, double rel_tol,
#endif

// allocate vectors
int num_threads = options.num_threads;
#if SUNDIALS_VERSION_MAJOR >= 6
yy = N_VNew_Serial(number_of_states, sunctx);
yp = N_VNew_Serial(number_of_states, sunctx);
avtol = N_VNew_Serial(number_of_states, sunctx);
id = N_VNew_Serial(number_of_states, sunctx);
yy = N_VNew_OpenMP(number_of_states, num_threads, sunctx);
yp = N_VNew_OpenMP(number_of_states, num_threads, sunctx);
avtol = N_VNew_OpenMP(number_of_states, num_threads, sunctx);
id = N_VNew_OpenMP(number_of_states, num_threads, sunctx);
#else
yy = N_VNew_Serial(number_of_states);
yp = N_VNew_Serial(number_of_states);
avtol = N_VNew_Serial(number_of_states);
id = N_VNew_Serial(number_of_states);
yy = N_VNew_OpenMP(number_of_states, num_threads);
yp = N_VNew_OpenMP(number_of_states, num_threads);
avtol = N_VNew_OpenMP(number_of_states, num_threads);
id = N_VNew_OpenMP(number_of_states, num_threads);
#endif

if (number_of_parameters > 0)
Expand Down
36 changes: 18 additions & 18 deletions pybamm/solvers/c_solvers/idaklu/casadi_sundials_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,19 @@ int residual_casadi(realtype tres, N_Vector yy, N_Vector yp, N_Vector rr,
static_cast<CasadiFunctions *>(user_data);

p_python_functions->rhs_alg.m_arg[0] = &tres;
p_python_functions->rhs_alg.m_arg[1] = NV_DATA_S(yy);
p_python_functions->rhs_alg.m_arg[1] = NV_DATA_OMP(yy);
p_python_functions->rhs_alg.m_arg[2] = p_python_functions->inputs.data();
p_python_functions->rhs_alg.m_res[0] = NV_DATA_S(rr);
p_python_functions->rhs_alg.m_res[0] = NV_DATA_OMP(rr);
p_python_functions->rhs_alg();

realtype *tmp = p_python_functions->get_tmp_state_vector();
p_python_functions->mass_action.m_arg[0] = NV_DATA_S(yp);
p_python_functions->mass_action.m_arg[0] = NV_DATA_OMP(yp);
p_python_functions->mass_action.m_res[0] = tmp;
p_python_functions->mass_action();

// AXPY: y <- a*x + y
const int ns = p_python_functions->number_of_states;
casadi::casadi_axpy(ns, -1., tmp, NV_DATA_S(rr));
casadi::casadi_axpy(ns, -1., tmp, NV_DATA_OMP(rr));

DEBUG_VECTOR(yy);
DEBUG_VECTOR(yp);
Expand Down Expand Up @@ -101,22 +101,22 @@ int jtimes_casadi(realtype tt, N_Vector yy, N_Vector yp, N_Vector rr,

// Jv has ∂F/∂y v
p_python_functions->jac_action.m_arg[0] = &tt;
p_python_functions->jac_action.m_arg[1] = NV_DATA_S(yy);
p_python_functions->jac_action.m_arg[1] = NV_DATA_OMP(yy);
p_python_functions->jac_action.m_arg[2] = p_python_functions->inputs.data();
p_python_functions->jac_action.m_arg[3] = NV_DATA_S(v);
p_python_functions->jac_action.m_res[0] = NV_DATA_S(Jv);
p_python_functions->jac_action.m_arg[3] = NV_DATA_OMP(v);
p_python_functions->jac_action.m_res[0] = NV_DATA_OMP(Jv);
p_python_functions->jac_action();

// tmp has -∂F/∂y˙ v
realtype *tmp = p_python_functions->get_tmp_state_vector();
p_python_functions->mass_action.m_arg[0] = NV_DATA_S(v);
p_python_functions->mass_action.m_arg[0] = NV_DATA_OMP(v);
p_python_functions->mass_action.m_res[0] = tmp;
p_python_functions->mass_action();

// AXPY: y <- a*x + y
// Jv has ∂F/∂y v + cj ∂F/∂y˙ v
const int ns = p_python_functions->number_of_states;
casadi::casadi_axpy(ns, -cj, tmp, NV_DATA_S(Jv));
casadi::casadi_axpy(ns, -cj, tmp, NV_DATA_OMP(Jv));

return 0;
}
Expand Down Expand Up @@ -163,7 +163,7 @@ int jacobian_casadi(realtype tt, realtype cj, N_Vector yy, N_Vector yp,

// args are t, y, cj, put result in jacobian data matrix
p_python_functions->jac_times_cjmass.m_arg[0] = &tt;
p_python_functions->jac_times_cjmass.m_arg[1] = NV_DATA_S(yy);
p_python_functions->jac_times_cjmass.m_arg[1] = NV_DATA_OMP(yy);
p_python_functions->jac_times_cjmass.m_arg[2] =
p_python_functions->inputs.data();
p_python_functions->jac_times_cjmass.m_arg[3] = &cj;
Expand Down Expand Up @@ -227,7 +227,7 @@ int events_casadi(realtype t, N_Vector yy, N_Vector yp, realtype *events_ptr,

// args are t, y, put result in events_ptr
p_python_functions->events.m_arg[0] = &t;
p_python_functions->events.m_arg[1] = NV_DATA_S(yy);
p_python_functions->events.m_arg[1] = NV_DATA_OMP(yy);
p_python_functions->events.m_arg[2] = p_python_functions->inputs.data();
p_python_functions->events.m_res[0] = events_ptr;
p_python_functions->events();
Expand Down Expand Up @@ -270,11 +270,11 @@ int sensitivities_casadi(int Ns, realtype t, N_Vector yy, N_Vector yp,

// args are t, y put result in rr
p_python_functions->sens.m_arg[0] = &t;
p_python_functions->sens.m_arg[1] = NV_DATA_S(yy);
p_python_functions->sens.m_arg[1] = NV_DATA_OMP(yy);
p_python_functions->sens.m_arg[2] = p_python_functions->inputs.data();
for (int i = 0; i < np; i++)
{
p_python_functions->sens.m_res[i] = NV_DATA_S(resvalS[i]);
p_python_functions->sens.m_res[i] = NV_DATA_OMP(resvalS[i]);
}
// resvalsS now has (∂F/∂p i )
p_python_functions->sens();
Expand All @@ -284,23 +284,23 @@ int sensitivities_casadi(int Ns, realtype t, N_Vector yy, N_Vector yp,
// put (∂F/∂y)s i (t) in tmp
realtype *tmp = p_python_functions->get_tmp_state_vector();
p_python_functions->jac_action.m_arg[0] = &t;
p_python_functions->jac_action.m_arg[1] = NV_DATA_S(yy);
p_python_functions->jac_action.m_arg[1] = NV_DATA_OMP(yy);
p_python_functions->jac_action.m_arg[2] = p_python_functions->inputs.data();
p_python_functions->jac_action.m_arg[3] = NV_DATA_S(yS[i]);
p_python_functions->jac_action.m_arg[3] = NV_DATA_OMP(yS[i]);
p_python_functions->jac_action.m_res[0] = tmp;
p_python_functions->jac_action();

const int ns = p_python_functions->number_of_states;
casadi::casadi_axpy(ns, 1., tmp, NV_DATA_S(resvalS[i]));
casadi::casadi_axpy(ns, 1., tmp, NV_DATA_OMP(resvalS[i]));

// put -(∂F/∂ ẏ) ṡ i (t) in tmp2
p_python_functions->mass_action.m_arg[0] = NV_DATA_S(ypS[i]);
p_python_functions->mass_action.m_arg[0] = NV_DATA_OMP(ypS[i]);
p_python_functions->mass_action.m_res[0] = tmp;
p_python_functions->mass_action();

// (∂F/∂y)s i (t)+(∂F/∂ ẏ) ṡ i (t)+(∂F/∂p i )
// AXPY: y <- a*x + y
casadi::casadi_axpy(ns, -1., tmp, NV_DATA_S(resvalS[i]));
casadi::casadi_axpy(ns, -1., tmp, NV_DATA_OMP(resvalS[i]));
}

return 0;
Expand Down
1 change: 1 addition & 0 deletions pybamm/solvers/c_solvers/idaklu/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <idas/idas_bbdpre.h> /* access to IDABBDPRE preconditioner */

#include <nvector/nvector_serial.h> /* access to serial N_Vector */
#include <nvector/nvector_openmp.h> /* access to openmp N_Vector */
#include <sundials/sundials_math.h> /* defs. of SUNRabs, SUNRexp, etc. */
#include <sundials/sundials_config.h> /* defs. of SUNRabs, SUNRexp, etc. */
#include <sundials/sundials_types.h> /* defs. of realtype, sunindextype */
Expand Down
3 changes: 2 additions & 1 deletion pybamm/solvers/c_solvers/idaklu/options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ Options::Options(py::dict options)
linsol_max_iterations(options["linsol_max_iterations"].cast<int>()),
linear_solver(options["linear_solver"].cast<std::string>()),
precon_half_bandwidth(options["precon_half_bandwidth"].cast<int>()),
precon_half_bandwidth_keep(options["precon_half_bandwidth_keep"].cast<int>())
precon_half_bandwidth_keep(options["precon_half_bandwidth_keep"].cast<int>()),
num_threads(options["num_threads"].cast<int>())
{

using_sparse_matrix = true;
Expand Down
1 change: 1 addition & 0 deletions pybamm/solvers/c_solvers/idaklu/options.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ struct Options {
int linsol_max_iterations;
int precon_half_bandwidth;
int precon_half_bandwidth_keep;
int num_threads;
explicit Options(py::dict options);

};
Expand Down
4 changes: 4 additions & 0 deletions pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ class IDAKLUSolver(pybamm.BaseSolver):
# for iterative linear solver preconditioner, bandwidth of
# approximate jacobian that is kept
"precon_half_bandwidth_keep": 5
# Number of threads available for OpenMP
"num_threads": 1
}
Note: These options only have an effect if model.convert_to_format == 'casadi'
Expand All @@ -100,6 +103,7 @@ def __init__(
"linsol_max_iterations": 5,
"precon_half_bandwidth": 5,
"precon_half_bandwidth_keep": 5,
"num_threads": 1,
}
if options is None:
options = default_options
Expand Down
5 changes: 3 additions & 2 deletions scripts/install_KLU_Sundials.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,11 @@ def download_extract_library(url, download_dir):
KLU_INCLUDE_DIR = os.path.join(install_dir, "include")
KLU_LIBRARY_DIR = os.path.join(install_dir, "lib")
cmake_args = [
"-DLAPACK_ENABLE=ON",
"-DENABLE_LAPACK=ON",
"-DSUNDIALS_INDEX_SIZE=32",
"-DEXAMPLES_ENABLE:BOOL=OFF",
"-DKLU_ENABLE=ON",
"-DENABLE_KLU=ON",
"-DENABLE_OPENMP=ON",
"-DKLU_INCLUDE_DIR={}".format(KLU_INCLUDE_DIR),
"-DKLU_LIBRARY_DIR={}".format(KLU_LIBRARY_DIR),
"-DCMAKE_INSTALL_PREFIX=" + install_dir,
Expand Down

0 comments on commit 137ae23

Please sign in to comment.