Skip to content

Commit

Permalink
feat: add OpenMP parallelization to IDAKLU solver for lists of input …
Browse files Browse the repository at this point in the history
…parameters (pybamm-team#4449)

* new solver option `num_solvers`, indicates how many solves run in parallel
* existing `num_threads` gives total number of threads which are distributed among `num_solvers`
  • Loading branch information
martinjrobins authored Sep 18, 2024
1 parent e1118ec commit 48dbb68
Show file tree
Hide file tree
Showing 20 changed files with 677 additions and 256 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Features
- Added sensitivity calculation support for `pybamm.Simulation` and `pybamm.Experiment` ([#4415](https://github.com/pybamm-team/PyBaMM/pull/4415))
- Added OpenMP parallelization to IDAKLU solver for lists of input parameters ([#4449](https://github.com/pybamm-team/PyBaMM/pull/4449))

## Optimizations
- Removed the `start_step_offset` setting and disabled minimum `dt` warnings for drive cycles with the `IDAKLUSolver`. ([#4416](https://github.com/pybamm-team/PyBaMM/pull/4416))
Expand Down
23 changes: 22 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ endif()

project(idaklu)

set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_EXPORT_COMPILE_COMMANDS 1)
Expand Down Expand Up @@ -82,6 +82,8 @@ pybind11_add_module(idaklu
src/pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp
src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.cpp
src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp
src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.cpp
src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.hpp
src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl
src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp
src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.cpp
Expand All @@ -94,6 +96,8 @@ pybind11_add_module(idaklu
src/pybamm/solvers/c_solvers/idaklu/common.cpp
src/pybamm/solvers/c_solvers/idaklu/Solution.cpp
src/pybamm/solvers/c_solvers/idaklu/Solution.hpp
src/pybamm/solvers/c_solvers/idaklu/SolutionData.cpp
src/pybamm/solvers/c_solvers/idaklu/SolutionData.hpp
src/pybamm/solvers/c_solvers/idaklu/Options.hpp
src/pybamm/solvers/c_solvers/idaklu/Options.cpp
# IDAKLU expressions / function evaluation [abstract]
Expand Down Expand Up @@ -138,6 +142,23 @@ set_target_properties(
INSTALL_RPATH_USE_LINK_PATH TRUE
)

# OpenMP (optional): the module is linked against it only when find_package
# locates a usable C++ OpenMP installation.
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
  # On macOS, add Homebrew's libomp prefix to the OpenMP search roots.
  execute_process(
    COMMAND "brew" "--prefix"
    RESULT_VARIABLE HOMEBREW_RESULT
    OUTPUT_VARIABLE HOMEBREW_PREFIX
    OUTPUT_STRIP_TRAILING_WHITESPACE)
  # Only extend the search roots when `brew --prefix` actually succeeded;
  # otherwise HOMEBREW_PREFIX is empty and we would add "/opt/libomp".
  if (HOMEBREW_RESULT EQUAL 0 AND HOMEBREW_PREFIX)
    # <PackageName>_ROOT is a CMake list (semicolon separated), so append to
    # it rather than joining with ':' (that separator is only valid for the
    # environment variable of the same name on UNIX).
    list(APPEND OpenMP_ROOT "${HOMEBREW_PREFIX}/opt/libomp")
  endif()
endif()
find_package(OpenMP)
if(OpenMP_CXX_FOUND)
  target_link_libraries(idaklu PRIVATE OpenMP::OpenMP_CXX)
endif()

set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${PROJECT_SOURCE_DIR})
# Sundials
find_package(SUNDIALS REQUIRED)
Expand Down
59 changes: 35 additions & 24 deletions src/pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def supports_interp(self):
def root_method(self):
return self._root_method

@property
def supports_parallel_solve(self):
return False

@root_method.setter
def root_method(self, method):
if method == "casadi":
Expand Down Expand Up @@ -896,36 +900,37 @@ def solve(
pybamm.logger.verbose(
f"Calling solver for {t_eval[start_index]} < t < {t_eval[end_index - 1]}"
)
ninputs = len(model_inputs_list)
if ninputs == 1:
new_solution = self._integrate(
model,
t_eval[start_index:end_index],
model_inputs_list[0],
t_interp=t_interp,
)
new_solutions = [new_solution]
elif model.convert_to_format == "jax":
# Jax can parallelize over the inputs efficiently
if self.supports_parallel_solve:
# Jax and IDAKLU solver can accept a list of inputs
new_solutions = self._integrate(
model,
t_eval[start_index:end_index],
model_inputs_list,
t_interp,
)
else:
with mp.get_context(self._mp_context).Pool(processes=nproc) as p:
new_solutions = p.starmap(
self._integrate,
zip(
[model] * ninputs,
[t_eval[start_index:end_index]] * ninputs,
model_inputs_list,
[t_interp] * ninputs,
),
ninputs = len(model_inputs_list)
if ninputs == 1:
new_solution = self._integrate(
model,
t_eval[start_index:end_index],
model_inputs_list[0],
t_interp=t_interp,
)
p.close()
p.join()
new_solutions = [new_solution]
else:
with mp.get_context(self._mp_context).Pool(processes=nproc) as p:
new_solutions = p.starmap(
self._integrate,
zip(
[model] * ninputs,
[t_eval[start_index:end_index]] * ninputs,
model_inputs_list,
[t_interp] * ninputs,
),
)
p.close()
p.join()
# Setting the solve time for each segment.
# pybamm.Solution.__add__ assumes attribute solve_time.
solve_time = timer.time()
Expand Down Expand Up @@ -995,7 +1000,7 @@ def solve(
)

# Return solution(s)
if ninputs == 1:
if len(solutions) == 1:
return solutions[0]
else:
return solutions
Expand Down Expand Up @@ -1350,7 +1355,13 @@ def step(
# Step
pybamm.logger.verbose(f"Stepping for {t_start_shifted:.0f} < t < {t_end:.0f}")
timer.reset()
solution = self._integrate(model, t_eval, model_inputs, t_interp)

# API for _integrate is different for JaxSolver and IDAKLUSolver
if self.supports_parallel_solve:
solutions = self._integrate(model, t_eval, [model_inputs], t_interp)
solution = solutions[0]
else:
solution = self._integrate(model, t_eval, model_inputs, t_interp)
solution.solve_time = timer.time()

# Check if extrapolation occurred
Expand Down
15 changes: 9 additions & 6 deletions src/pybamm/solvers/c_solvers/idaklu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <pybind11/stl_bind.h>

#include "idaklu/idaklu_solver.hpp"
#include "idaklu/IDAKLUSolverGroup.hpp"
#include "idaklu/IdakluJax.hpp"
#include "idaklu/common.hpp"
#include "idaklu/Expressions/Casadi/CasadiFunctions.hpp"
Expand All @@ -26,15 +27,17 @@ casadi::Function generate_casadi_function(const std::string &data)
namespace py = pybind11;

PYBIND11_MAKE_OPAQUE(std::vector<np_array>);
PYBIND11_MAKE_OPAQUE(std::vector<Solution>);

PYBIND11_MODULE(idaklu, m)
{
m.doc() = "sundials solvers"; // optional module docstring

py::bind_vector<std::vector<np_array>>(m, "VectorNdArray");
py::bind_vector<std::vector<Solution>>(m, "VectorSolution");

py::class_<IDAKLUSolver>(m, "IDAKLUSolver")
.def("solve", &IDAKLUSolver::solve,
py::class_<IDAKLUSolverGroup>(m, "IDAKLUSolverGroup")
.def("solve", &IDAKLUSolverGroup::solve,
"perform a solve",
py::arg("t_eval"),
py::arg("t_interp"),
Expand All @@ -43,8 +46,8 @@ PYBIND11_MODULE(idaklu, m)
py::arg("inputs"),
py::return_value_policy::take_ownership);

m.def("create_casadi_solver", &create_idaklu_solver<CasadiFunctions>,
"Create a casadi idaklu solver object",
m.def("create_casadi_solver_group", &create_idaklu_solver_group<CasadiFunctions>,
"Create a group of casadi idaklu solver objects",
py::arg("number_of_states"),
py::arg("number_of_parameters"),
py::arg("rhs_alg"),
Expand All @@ -70,8 +73,8 @@ PYBIND11_MODULE(idaklu, m)
py::return_value_policy::take_ownership);

#ifdef IREE_ENABLE
m.def("create_iree_solver", &create_idaklu_solver<IREEFunctions>,
"Create a iree idaklu solver object",
m.def("create_iree_solver_group", &create_idaklu_solver_group<IREEFunctions>,
"Create a group of iree idaklu solver objects",
py::arg("number_of_states"),
py::arg("number_of_parameters"),
py::arg("rhs_alg"),
Expand Down
20 changes: 12 additions & 8 deletions src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
#define PYBAMM_IDAKLU_CASADI_SOLVER_HPP

#include "common.hpp"
#include "Solution.hpp"
#include "SolutionData.hpp"


/**
* Abstract base class for solutions that can use different solvers and vector
Expand All @@ -24,14 +25,17 @@ class IDAKLUSolver
~IDAKLUSolver() = default;

/**
* @brief Abstract solver method that returns a Solution class
* @brief Abstract solver method that executes the solver
*/
virtual Solution solve(
np_array t_eval_np,
np_array t_interp_np,
np_array y0_np,
np_array yp0_np,
np_array_dense inputs) = 0;
virtual SolutionData solve(
const std::vector<realtype> &t_eval,
const std::vector<realtype> &t_interp,
const realtype *y0,
const realtype *yp0,
const realtype *inputs,
bool save_adaptive_steps,
bool save_interp_steps
) = 0;

/**
* Abstract method to initialize the solver, once vectors and solver classes
Expand Down
145 changes: 145 additions & 0 deletions src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#include "IDAKLUSolverGroup.hpp"
#include <omp.h>
#include <optional>

/**
 * @brief Run one solve per row of `y0_np` / `yp0_np` / `inputs`, distributing
 * the rows over the solvers held in `m_solvers`.
 *
 * The rows are split into `m_solvers.size()` equally sized contiguous chunks
 * that are solved inside an OpenMP parallel loop (one chunk per solver). Any
 * rows left over after the even split are solved serially afterwards.
 *
 * @param t_eval_np   Times handed to the solvers; sorted and de-duplicated
 *                    here, and must contain at least 2 distinct entries.
 * @param t_interp_np Interpolation times. If empty, all adaptive steps are
 *                    saved. Values also present in t_eval are removed; the
 *                    remaining values must lie within the t_eval range.
 * @param y0_np       2D array of initial states, one row per solve, with
 *                    number_of_states * (1 + number_of_parameters) columns.
 * @param yp0_np      2D array of initial state derivatives, same shape as y0.
 * @param inputs      2D array of input parameters, one row per solve.
 * @return One Solution per row, in row order.
 *
 * @throws std::invalid_argument if the time vectors are invalid.
 * @throws std::domain_error if y0, yp0 or inputs have the wrong shape.
 * @throws std::runtime_error if the group holds no solvers.
 * @throws py::error_already_set (Python ValueError) if a solve inside the
 *         parallel region fails.
 */
std::vector<Solution> IDAKLUSolverGroup::solve(
    np_array t_eval_np,
    np_array t_interp_np,
    np_array y0_np,
    np_array yp0_np,
    np_array inputs) {
  DEBUG("IDAKLUSolverGroup::solve");

  // If t_interp is empty, save all adaptive steps
  const bool save_adaptive_steps = t_interp_np.size() == 0;

  const realtype* t_eval_begin = t_eval_np.data();
  const realtype* t_eval_end = t_eval_begin + t_eval_np.size();
  const realtype* t_interp_begin = t_interp_np.data();
  const realtype* t_interp_end = t_interp_begin + t_interp_np.size();

  // Process the time inputs
  // 1. Get the sorted and unique t_eval vector
  auto const t_eval = makeSortedUnique(t_eval_begin, t_eval_end);

  // 2.1. Get the sorted and unique t_interp vector
  auto const t_interp_unique_sorted = makeSortedUnique(t_interp_begin, t_interp_end);

  // 2.2 Remove the t_eval values from t_interp.
  // Use the sorted, de-duplicated copy of t_eval (same set of values as the
  // raw buffer) so that the subtraction does not depend on the caller having
  // passed t_eval in order.
  // NOTE(review): assumes setDiff is a sorted-range set difference - confirm.
  auto const t_interp_setdiff = setDiff(
      t_interp_unique_sorted.begin(), t_interp_unique_sorted.end(),
      t_eval.data(), t_eval.data() + t_eval.size());

  // 2.3 Finally, get the sorted and unique t_interp vector with t_eval values removed
  auto const t_interp = makeSortedUnique(t_interp_setdiff.begin(), t_interp_setdiff.end());

  int const number_of_evals = t_eval.size();
  int const number_of_interps = t_interp.size();

  // setDiff removes entries of t_interp that overlap with
  // t_eval, so we need to check if we need to interpolate any unique points.
  // This is not the same as save_adaptive_steps since some entries of t_interp
  // may be removed by setDiff
  const bool save_interp_steps = number_of_interps > 0;

  // 3. Check if the timestepping entries are valid
  if (number_of_evals < 2) {
    throw std::invalid_argument(
      "t_eval must have at least 2 entries"
    );
  } else if (save_interp_steps) {
    if (t_interp.front() < t_eval.front()) {
      throw std::invalid_argument(
        "t_interp values must be greater than the smallest t_eval value: "
        + std::to_string(t_eval.front())
      );
    } else if (t_interp.back() > t_eval.back()) {
      throw std::invalid_argument(
        "t_interp values must be less than the greatest t_eval value: "
        + std::to_string(t_eval.back())
      );
    }
  }

  auto n_coeffs = number_of_states + number_of_parameters * number_of_states;

  // check y0 and yp0 and inputs have the correct dimensions
  if (y0_np.ndim() != 2)
    throw std::domain_error("y0 has wrong number of dimensions. Expected 2 but got " + std::to_string(y0_np.ndim()));
  if (yp0_np.ndim() != 2)
    throw std::domain_error("yp0 has wrong number of dimensions. Expected 2 but got " + std::to_string(yp0_np.ndim()));
  if (inputs.ndim() != 2)
    throw std::domain_error("inputs has wrong number of dimensions. Expected 2 but got " + std::to_string(inputs.ndim()));

  auto number_of_groups = y0_np.shape()[0];

  // check y0 and yp0 and inputs have the correct shape
  if (y0_np.shape()[1] != n_coeffs)
    throw std::domain_error(
      "y0 has wrong number of cols. Expected " + std::to_string(n_coeffs) +
      " but got " + std::to_string(y0_np.shape()[1]));

  if (yp0_np.shape()[1] != n_coeffs)
    throw std::domain_error(
      "yp0 has wrong number of cols. Expected " + std::to_string(n_coeffs) +
      " but got " + std::to_string(yp0_np.shape()[1]));

  if (yp0_np.shape()[0] != number_of_groups)
    throw std::domain_error(
      "yp0 has wrong number of rows. Expected " + std::to_string(number_of_groups) +
      " but got " + std::to_string(yp0_np.shape()[0]));

  if (inputs.shape()[0] != number_of_groups)
    throw std::domain_error(
      "inputs has wrong number of rows. Expected " + std::to_string(number_of_groups) +
      " but got " + std::to_string(inputs.shape()[0]));

  // The solver count is used as a divisor below, so an empty group must be
  // rejected explicitly instead of dividing by zero.
  if (m_solvers.empty())
    throw std::runtime_error("IDAKLUSolverGroup has no solvers");

  const int number_of_solvers = static_cast<int>(m_solvers.size());
  const std::size_t number_of_rows = static_cast<std::size_t>(number_of_groups);
  const std::size_t solves_per_thread = number_of_rows / m_solvers.size();
  const std::size_t remainder_solves = number_of_rows % m_solvers.size();

  const realtype *y0 = y0_np.data();
  const realtype *yp0 = yp0_np.data();
  const realtype *inputs_data = inputs.data();

  // Row lengths are read once here so that the worker threads only perform
  // plain pointer arithmetic.
  // NOTE(review): the row addressing assumes the arrays are C-contiguous - confirm.
  const std::size_t y0_stride = static_cast<std::size_t>(y0_np.shape(1));
  const std::size_t yp0_stride = static_cast<std::size_t>(yp0_np.shape(1));
  const std::size_t inputs_stride = static_cast<std::size_t>(inputs.shape(1));

  std::vector<SolutionData> results(number_of_groups);

  // Solve a single row with the given solver and store the result in its slot.
  const auto solve_row = [&](std::size_t solver_index, std::size_t row) {
    results[row] = m_solvers[solver_index]->solve(
        t_eval, t_interp,
        y0 + row * y0_stride,
        yp0 + row * yp0_stride,
        inputs_data + row * inputs_stride,
        save_adaptive_steps, save_interp_steps);
  };

  // Only the message is kept: copying the caught object into a
  // std::exception would slice off the derived type and lose its what().
  std::optional<std::string> error_message;

  omp_set_num_threads(number_of_solvers);
  #pragma omp parallel for
  for (int i = 0; i < number_of_solvers; i++) {
    // Exceptions must not leave the parallel region, so they are recorded
    // here and reported once the region has finished.
    try {
      const std::size_t first_row = static_cast<std::size_t>(i) * solves_per_thread;
      for (std::size_t j = 0; j < solves_per_thread; j++) {
        solve_row(static_cast<std::size_t>(i), first_row + j);
      }
    } catch (const std::exception &e) {
      #pragma omp critical
      {
        if (!error_message.has_value()) {
          error_message = e.what();
        }
      }
    } catch (...) {
      #pragma omp critical
      {
        if (!error_message.has_value()) {
          error_message = "unknown exception raised during solve";
        }
      }
    }
  }

  if (error_message.has_value()) {
    py::set_error(PyExc_ValueError, error_message->c_str());
    throw py::error_already_set();
  }

  // Rows left over after the even split are solved serially, reusing the
  // first `remainder_solves` solvers.
  for (std::size_t i = 0; i < remainder_solves; i++) {
    solve_row(i, number_of_rows - remainder_solves + i);
  }

  // create solutions (needs to be serial as we're using the Python GIL)
  std::vector<Solution> solutions(number_of_groups);
  for (std::size_t row = 0; row < number_of_rows; row++) {
    solutions[row] = results[row].generate_solution();
  }
  return solutions;
}
Loading

0 comments on commit 48dbb68

Please sign in to comment.